Lecture 8. Neural Networks
Contents
Lecture 8. Neural Networks¶
How to train your neurons
Joaquin Vanschoren
# Note: You'll need to install tensorflow-addons. One of these should work
# !pip install tensorflow_addons
# !pip install tfa-nightly
# TODO: Fix issues running Cyclical Learning rate and AdaMax with latest TF
# Auto-setup when running on Google Colab
# (uses IPython shell magics, so this cell only works inside a notebook)
import os
if 'google.colab' in str(get_ipython()) and not os.path.exists('/content/master'):
    !git clone -q https://github.com/ML-course/master.git /content/master
    !pip install -rq master/requirements_colab.txt
    %cd master/notebooks
# Global imports and settings
%matplotlib inline
from preamble import *  # shared lecture imports (plt, np, oml, print_config, ...)
interactive = False # Set to True for interactive plots
if interactive:
    fig_scale = 0.7
    # NOTE(review): both branches apply print_config — presumably the interactive
    # branch should use a different (screen) rcParams config; verify against preamble.
    plt.rcParams.update(print_config)
else: # For printing
    fig_scale = 0.8
    plt.rcParams.update(print_config)
Overview¶
Neural architectures
Training neural nets
Forward pass: Tensor operations
Backward pass: Backpropagation
Neural network design:
Activation functions
Weight initialization
Optimizers
Neural networks in practice
Model selection
Early stopping
Memorization capacity and information bottleneck
L1/L2 regularization
Dropout
Batch normalization
def draw_neural_net(ax, layer_sizes, draw_bias=False, labels=False, activation=False, sigmoid=False,
                    weight_count=False, random_weights=False, show_activations=False, figsize=(4, 4)):
    """
    Draws a dense (fully-connected) neural net for educational purposes.

    Parameters:
    ax: matplotlib axis to draw on
    layer_sizes: array with the sizes of every layer
    draw_bias: whether to draw bias nodes
    labels: whether to draw labels for the weights and nodes
    activation: whether to show the activation function inside the nodes
    sigmoid: whether the last activation function is a sigmoid
    weight_count: whether to show the number of weights and biases
    random_weights: whether to show random weights as colored lines
    show_activations: whether to show a variable for the node activations
    figsize: (width, height) used to scale the drawing area
    """
    # Drawing area in axis coordinates; the width is scaled by the figure aspect ratio.
    left, right, bottom, top = 0.1, 0.89*figsize[0]/figsize[1], 0.1, 0.89
    n_layers = len(layer_sizes)
    v_spacing = (top - bottom)/float(max(layer_sizes))      # vertical gap between nodes
    h_spacing = (right - left)/float(len(layer_sizes) - 1)  # horizontal gap between layers
    colors = ['greenyellow','cornflowerblue','lightcoral']  # input / hidden / output node colors
    w_count, b_count = 0, 0  # running counts of weights and biases
    ax.set_xlim(0, figsize[0]/figsize[1])
    ax.axis('off')
    ax.set_aspect('equal')
    txtargs = {"fontsize":12, "verticalalignment":'center', "horizontalalignment":'center', "zorder":5}
    # Draw biases by adding a node to every layer except the last one
    if draw_bias:
        layer_sizes = [x+1 for x in layer_sizes]
        layer_sizes[-1] = layer_sizes[-1] - 1
    # Nodes
    for n, layer_size in enumerate(layer_sizes):
        layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
        # Larger nodes when the activation function is drawn inside them
        node_size = v_spacing/len(layer_sizes) if activation and n!=0 else v_spacing/3.
        if n==0:
            color = colors[0]
        elif n==len(layer_sizes)-1:
            color = colors[2]
        else:
            color = colors[1]
        for m in range(layer_size):
            ax.add_artist(plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), radius=node_size,
                                     color=color, ec='k', zorder=4))
            b_count += 1
            nx, ny = n*h_spacing + left, layer_top - m*v_spacing  # node center
            # End points of the vertical divider drawn inside nodes that show their activation
            nsx, nsy = [n*h_spacing + left,n*h_spacing + left], [layer_top - m*v_spacing - 0.5*node_size*2,layer_top - m*v_spacing + 0.5*node_size*2]
            if draw_bias and m==0 and n<len(layer_sizes)-1:
                ax.text(nx, ny, r'$1$', **txtargs)  # bias node always outputs 1
            elif labels and n==0:
                ax.text(n*h_spacing + left,layer_top + v_spacing/1.5, 'input', **txtargs)
                ax.text(nx, ny, r'$x_{}$'.format(m), **txtargs)
            elif labels and n==len(layer_sizes)-1:
                if activation:
                    if sigmoid:
                        ax.text(n*h_spacing + left,layer_top - m*v_spacing, r"$z \;\;\; \sigma$", **txtargs)
                    else:
                        ax.text(n*h_spacing + left,layer_top - m*v_spacing, r"$z_{} \;\; g$".format(m), **txtargs)
                    ax.add_artist(plt.Line2D(nsx, nsy, c='k', zorder=6))  # divider between z and g
                    if show_activations:
                        ax.text(n*h_spacing + left + 1.5*node_size,layer_top - m*v_spacing, r"$\hat{y}$", fontsize=12,
                                verticalalignment='center', horizontalalignment='left', zorder=5, c='r')
                else:
                    ax.text(nx, ny, r'$o_{}$'.format(m), **txtargs)
                ax.text(n*h_spacing + left,layer_top + v_spacing, 'output', **txtargs)
            elif labels:
                if activation:
                    ax.text(n*h_spacing + left,layer_top - m*v_spacing, r"$z_{} \;\; f$".format(m), **txtargs)
                    ax.add_artist(plt.Line2D(nsx, nsy, c='k', zorder=6))
                    if show_activations:
                        ax.text(n*h_spacing + left + node_size,layer_top - m*v_spacing, r"$a_{}$".format(m), fontsize=12,
                                verticalalignment='center', horizontalalignment='left', zorder=5, c='b')
                else:
                    ax.text(nx, ny, r'$h_{}$'.format(m), **txtargs)
    # Edges
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
        layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
        for m in range(layer_size_a):
            for o in range(layer_size_b):
                # Skip edges that would feed *into* a bias node of the next layer
                if not (draw_bias and o==0 and len(layer_sizes)>2 and n<layer_size_b-1):
                    xs = [n*h_spacing + left, (n + 1)*h_spacing + left]
                    ys = [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing]
                    color = 'k' if not random_weights else plt.cm.bwr(np.random.random())
                    ax.add_artist(plt.Line2D(xs, ys, c=color, lw=1, alpha=0.6))
                    if not (draw_bias and m==0):
                        w_count += 1  # bias edges are counted as biases, not weights
                    if labels and not random_weights:
                        wl = r'$w_{{{},{}}}$'.format(m,o) if layer_size_b>1 else r'$w_{}$'.format(m)
                        ax.text(xs[0]+np.diff(xs)/2, np.mean(ys)-np.diff(ys)/9, wl, ha='center', va='center',
                                fontsize=10)
    # Count
    if weight_count:
        b_count = b_count - layer_sizes[0]  # input nodes carry no bias
        if draw_bias:
            b_count = b_count - (len(layer_sizes) - 2)  # drawn bias nodes carry no bias themselves
        ax.text(right, bottom, "{} weights, {} biases".format(w_count, b_count), ha='center', va='center')
Linear models as a building block¶
Logistic regression, drawn in a different, neuro-inspired, way
Linear model: inner product (\(z\)) of input vector \(\mathbf{x}\) and weight vector \(\mathbf{w}\), plus bias \(w_0\)
Logistic (or sigmoid) function maps the output to a probability in [0,1]
Uses log loss (cross-entropy) and gradient descent to learn the weights
# Draw logistic regression as a one-layer net: 4 inputs + bias -> sigmoid output
fig = plt.figure(figsize=(3*fig_scale, 3*fig_scale))
ax = fig.gca()
draw_neural_net(ax, [4, 1], activation=True, draw_bias=True, labels=True, sigmoid=True)
Basic Architecture¶
Add one (or more) hidden layers \(h\) with \(k\) nodes (or units, cells, neurons)
Every ‘neuron’ is a tiny function, the network is an arbitrarily complex function
Weights \(w_{i,j}\) between node \(i\) and node \(j\) form a weight matrix \(\mathbf{W}^{(l)}\) per layer \(l\)
Every neuron weights the inputs \(\mathbf{x}\) and passes it through a non-linear activation function
Activation functions (\(f,g\)) can be different per layer, output \(\mathbf{a}\) is called the activation: $$\color{blue}{h(\mathbf{x})} = \color{blue}{\mathbf{a}} = f(\mathbf{z}) = f(\mathbf{W}^{(1)} \color{green}{\mathbf{x}}+\mathbf{w}^{(1)}_0) \quad \quad \color{red}{o(\mathbf{x})} = g(\mathbf{W}^{(2)} \color{blue}{\mathbf{a}}+\mathbf{w}^{(2)}_0)$$
# Side by side: plain 2-3-1 network vs. the same network annotated with activations
fig, axes = plt.subplots(1,2, figsize=(8, 4))
draw_neural_net(axes[0], [2, 3, 1], draw_bias=True, labels=True, weight_count=True)
draw_neural_net(axes[1], [2, 3, 1], activation=True, show_activations=True, draw_bias=True, labels=True, weight_count=True)
More layers¶
Add more layers, and more nodes per layer, to make the model more complex
For simplicity, we don’t draw the biases (but remember that they are there)
In dense (fully-connected) layers, every previous layer node is connected to all nodes
The output layer can also have multiple nodes (e.g. 1 per class in multi-class classification)
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
@interact
def plot_dense_net(nr_layers=(0,6,1), nr_nodes=(1,12,1)):
    """Draw a dense net with a variable number of hidden layers and nodes per layer."""
    fig = plt.figure(figsize=(6, 4))
    ax = fig.gca()
    ax.axis('off')
    hidden = [nr_nodes]*nr_layers
    draw_neural_net(ax, [5] + hidden + [5], weight_count=True, figsize=(6, 4));
# Render one static configuration when not running interactively
if not interactive:
    plot_dense_net(nr_layers=6, nr_nodes=10)
Why layers?¶
Each layer acts as a filter and learns a new representation of the data
Subsequent layers can learn iterative refinements
Easier than learning a complex relationship in one go
Example: for image input, each layer yields new (filtered) images
Can learn multiple mappings at once: weight tensor \(\mathit{W}\) yields activation tensor \(\mathit{A}\)
From low-level patterns (edges, end-points, …) to combinations thereof
Each neuron ‘lights up’ if certain patterns occur in the input

Other architectures¶
There exist MANY types of networks for many different tasks
Convolutional nets for image data, Recurrent nets for sequential data,…
Also used to learn representations (embeddings), generate new images, text,…

Training Neural Nets¶
Design the architecture, choose activation functions (e.g. sigmoids)
Choose a way to initialize the weights (e.g. random initialization)
Choose a loss function (e.g. log loss) to measure how well the model fits training data
Choose an optimizer (typically an SGD variant) to update the weights

Mini-batch Stochastic Gradient Descent (recap)¶
Draw a batch of batch_size training data \(\mathbf{X}\) and \(\mathbf{y}\)
Forward pass : pass \(\mathbf{X}\) though the network to yield predictions \(\mathbf{\hat{y}}\)
Compute the loss \(\mathcal{L}\) (mismatch between \(\mathbf{\hat{y}}\) and \(\mathbf{y}\))
Backward pass : Compute the gradient of the loss with regard to every weight
Backpropagate the gradients through all the layers
Update \(W\): \(W_{(i+1)} = W_{(i)} - \frac{\partial L(x, W_{(i)})}{\partial W} * \eta\)
Repeat until n passes (epochs) are made through the entire training set
# TODO: show the actual weight updates
@interact
def draw_updates(iteration=(1,100,1)):
    """Draw the network with randomly colored weights; the slider reseeds the RNG."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    np.random.seed(iteration)  # deterministic edge colors per iteration value
    draw_neural_net(ax, [6,5,5,3], labels=True, random_weights=True, show_activations=True, figsize=(6, 4));
if not interactive:
    draw_updates(iteration=1)
Forward pass¶
We can naturally represent the data as tensors
Numerical n-dimensional array (with n axes)
2D tensor: matrix (samples, features)
3D tensor: time series (samples, timesteps, features)
4D tensor: color images (samples, height, width, channels)
5D tensor: video (samples, frames, height, width, channels)

Tensor operations¶
The operations that the network performs on the data can be reduced to a series of tensor operations
These are also much easier to run on GPUs
A dense layer with sigmoid activation, input tensor \(\mathbf{X}\), weight tensor \(\mathbf{W}\), bias \(\mathbf{b}\):
y = sigmoid(np.dot(X, W) + b)
Tensor dot product for 2D inputs (\(a\) samples, \(b\) features, \(c\) hidden nodes)

Element-wise operations¶
Activation functions and addition are element-wise operations:
def sigmoid(x):
return 1/(1 + np.exp(-x))
def add(x, y):
    """Element-wise addition of x and y (NumPy broadcasting applies for arrays)."""
    total = x + y
    return total
Note: if y has a lower dimension than x, it will be broadcasted: axes are added to match the dimensionality, and y is repeated along the new axes
>>> np.array([[1,2],[3,4]]) + np.array([10,20])
array([[11, 22],
[13, 24]])
Backward pass (backpropagation)¶
For last layer, compute gradient of the loss function \(\mathcal{L}\) w.r.t all weights of layer \(l\)
Sum up the gradients for all \(\mathbf{x}_j\) in minibatch: \(\sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W^{(l)}}\)
Update all weights in a layer at once (with learning rate \(\eta\)): \(W_{(i+1)}^{(l)} = W_{(i)}^{(l)} - \eta \sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W_{(i)}^{(l)}}\)
Repeat for next layer, iterating backwards (most efficient, avoids redundant calculations)

Backpropagation (example)¶
Imagine feeding a single data point, output is \(\hat{y} = g(z) = g(w_0 + w_1 * a_1 + w_2 * a_2 +... + w_p * a_p)\)
Decrease loss by updating weights:
Update the weights of last layer to maximize improvement: \(w_{i,(new)} = w_{i} - \frac{\partial \mathcal{L}}{\partial w_i} * \eta\)
To compute gradient \(\frac{\partial \mathcal{L}}{\partial w_i}\) we need the chain rule: \(f(g(x)) = f'(g(x)) * g'(x)\) $\(\frac{\partial \mathcal{L}}{\partial w_i} = \color{red}{\frac{\partial \mathcal{L}}{\partial g}} \color{blue}{\frac{\partial \mathcal{g}}{\partial z_0}} \color{green}{\frac{\partial \mathcal{z_0}}{\partial w_i}}\)$
E.g., with \(\mathcal{L} = \frac{1}{2}(y-\hat{y})^2\) and sigmoid \(\sigma\): \(\frac{\partial \mathcal{L}}{\partial w_i} = \color{red}{(y - \hat{y})} * \color{blue}{\sigma'(z_0)} * \color{green}{a_i}\)
# Backpropagation example network: 2 inputs, one hidden layer, single output
fig = plt.figure(figsize=(4, 3.5))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 1], activation=True, draw_bias=True, labels=True,
                show_activations=True)
Backpropagation (2)¶
Another way to decrease the loss \(\mathcal{L}\) is to update the activations \(a_i\)
To update \(a_i = f(z_i)\), we need to update the weights of the previous layer
We want to nudge \(a_i\) in the right direction by updating \(w_{i,j}\): $\(\frac{\partial \mathcal{L}}{\partial w_{i,j}} = \frac{\partial \mathcal{L}}{\partial a_i} \frac{\partial a_i}{\partial z_i} \frac{\partial \mathcal{z_i}}{\partial w_{i,j}} = \left( \frac{\partial \mathcal{L}}{\partial g} \frac{\partial \mathcal{g}}{\partial z_0} \frac{\partial \mathcal{z_0}}{\partial a_i} \right) \frac{\partial a_i}{\partial z_i} \frac{\partial \mathcal{z_i}}{\partial w_{i,j}}\)$
We know \(\frac{\partial \mathcal{L}}{\partial g}\) and \(\frac{\partial \mathcal{g}}{\partial z_0}\) from the previous step, \(\frac{\partial \mathcal{z_0}}{\partial a_i} = w_i\), \(\frac{\partial a_i}{\partial z_i} = f'\) and \(\frac{\partial \mathcal{z_i}}{\partial w_{i,j}} = x_j\)
# Same network as before, redrawn for the hidden-layer gradient discussion
fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 1], activation=True, draw_bias=True, labels=True,
                show_activations=True)
Backpropagation (3)¶
With multiple output nodes, \(\mathcal{L}\) is the sum of all per-output (per-class) losses
\(\frac{\partial \mathcal{L}}{\partial a_i}\) is sum of the gradients for every output
Per layer, sum up gradients for every point \(\mathbf{x}\) in the batch: \(\sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W}\)
Update all weights of every layer \(l\)
\(W_{(i+1)}^{(l)} = W_{(i)}^{(l)} - \eta \sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W_{(i)}^{(l)}}\)
Repeat with a new batch of data until loss converges
# Deeper network (two hidden layers, two outputs) with randomly colored weights
fig = plt.figure(figsize=(8, 4))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 3, 2], activation=True, draw_bias=True, labels=True,
                random_weights=True, show_activations=True, figsize=(8, 4))
Backpropagation (summary)¶
The network output \(a_o\) is defined by the weights \(W^{(o)}\) and biases \(\mathbf{b}^{(o)}\) of the output layer, and
The activations of a hidden layer \(h_1\) with activation function \(a_{h_1}\), weights \(W^{(1)}\) and biases \(\mathbf{b^{(1)}}\):
Minimize the loss by SGD. For layer \(l\), compute \(\frac{\partial \mathcal{L}(a_o(x))}{\partial W_l}\) and \(\frac{\partial \mathcal{L}(a_o(x))}{\partial b_{l,i}}\) using the chain rule
Decomposes into gradient of layer above, gradient of activation function, gradient of layer input:

Weight initialization¶
Initializing weights to 0 is bad: all gradients in layer will be identical (symmetry)
Too small random weights shrink activations to 0 along the layers (vanishing gradient)
Too large random weights multiply along layers (exploding gradient, zig-zagging)
Ideal: small random weights + variance of input and output gradients remains the same
Glorot/Xavier initialization (for tanh): randomly sample from \(N(0,\sigma), \sigma = \sqrt{\frac{2}{\text{fan_in + fan_out}}}\)
fan_in: number of input units, fan_out: number of output units
He initialization (for ReLU): randomly sample from \(N(0,\sigma), \sigma = \sqrt{\frac{2}{\text{fan_in}}}\)
Uniform sampling (instead of \(N(0,\sigma)\)) for deeper networks (w.r.t. vanishing gradients)
# Deep network with random weight colors, illustrating weight initialization
fig, ax = plt.subplots(1,1, figsize=(6, 3))
draw_neural_net(ax, [3, 5, 5, 5, 5, 5, 3], random_weights=True, figsize=(6, 3))
Weight initialization: transfer learning¶
Instead of starting from scratch, start from weights previously learned from similar tasks
This is, to a big extent, how humans learn so fast
Transfer learning: learn weights on task T, transfer them to new network
Weights can be frozen, or finetuned to the new data
Only works if the previous task is ‘similar’ enough
Meta-learning: learn a good initialization across many related tasks

## Code adapted from Il Gu Yi: https://github.com/ilguyi/optimizers.numpy
from matplotlib.colors import LogNorm
import tensorflow as tf
import tensorflow_addons as tfa
# Toy surface used to visualize optimizer trajectories
def f(x, y):
    """2D toy loss surface with a narrow curved valley."""
    return (1.5 - x + x*y)**2 + (2.25 - x + x*y**2)**2 + (2.625 - x + x*y**3)**2
# Tensorflow optimizers
sgd = tf.optimizers.SGD(0.01)
lr_schedule = tf.optimizers.schedules.ExponentialDecay(0.02,decay_steps=100,decay_rate=0.96)
sgd_decay = tf.optimizers.SGD(learning_rate=lr_schedule)
# Cyclical LR and Adamax are disabled pending the TODO at the top of the notebook
#sgd_cyclic = tfa.optimizers.CyclicalLearningRate(initial_learning_rate= 0.1,
#maximal_learning_rate= 0.5, step_size=0.05)
#clr_schedule = tfa.optimizers.CyclicalLearningRate(initial_learning_rate=1e-4, maximal_learning_rate= 0.1,
#                step_size=100, scale_fn=lambda x : x)
#sgd_cyclic = tf.optimizers.SGD(learning_rate=clr_schedule)
momentum = tf.optimizers.SGD(0.005, momentum=0.9, nesterov=False)
nesterov = tf.optimizers.SGD(0.005, momentum=0.9, nesterov=True)
adagrad = tf.optimizers.Adagrad(0.4)
#adamax = tf.optimizers.Adamax(learning_rate=0.5, beta_1=0.9, beta_2=0.999)
#adadelta = tf.optimizers.Adadelta(learning_rate=1.0)
rmsprop = tf.optimizers.RMSprop(learning_rate=0.1)
rmsprop_momentum = tf.optimizers.RMSprop(learning_rate=0.1, momentum=0.9)
adam = tf.optimizers.Adam(learning_rate=0.2, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
optimizers = [sgd, sgd_decay, momentum, nesterov, adagrad, rmsprop, rmsprop_momentum, adam]#, sgd_cyclic, adamax]
opt_names = ['sgd', 'sgd_decay', 'momentum', 'nesterov', 'adagrad', 'rmsprop', 'rmsprop_mom', 'adam']#, 'sgd_cyclic','adamax']
cmap = plt.cm.get_cmap('tab10')
colors = [cmap(x/10) for x in range(10)]
# Training: run every optimizer from the same start point and record its path
all_paths = []
for opt, name in zip(optimizers, opt_names):
    x_init = 0.8
    # NOTE(review): get_variable reuses the names 'x'/'y' across loop iterations —
    # relies on TF2 eager behavior; confirm this resets the variables per optimizer.
    x = tf.compat.v1.get_variable('x', dtype=tf.float32, initializer=tf.constant(x_init))
    y_init = 1.6
    y = tf.compat.v1.get_variable('y', dtype=tf.float32, initializer=tf.constant(y_init))
    x_history = []
    y_history = []
    z_prev = 0.0
    max_steps = 100
    for step in range(max_steps):
        with tf.GradientTape() as g:
            z = f(x, y)
        x_history.append(x.numpy())
        y_history.append(y.numpy())
        dz_dx, dz_dy = g.gradient(z, [x, y])
        opt.apply_gradients(zip([dz_dx, dz_dy], [x, y]))
        # Stop early once the loss change becomes negligible
        if np.abs(z_prev - z.numpy()) < 1e-6:
            break
        z_prev = z.numpy()
    x_history = np.array(x_history)
    y_history = np.array(y_history)
    path = np.concatenate((np.expand_dims(x_history, 1), np.expand_dims(y_history, 1)), axis=1).T
    all_paths.append(path)
# Plotting: evaluate the surface on a grid for contour plots
number_of_points = 50
margin = 4.5
minima = np.array([3., .5])  # global minimum of the toy surface
minima_ = minima.reshape(-1, 1)
x_min = 0. - 2
x_max = 0. + 3.5
y_min = 0. - 3.5
y_max = 0. + 2
x_points = np.linspace(x_min, x_max, number_of_points)
y_points = np.linspace(y_min, y_max, number_of_points)
x_mesh, y_mesh = np.meshgrid(x_points, y_points)
z = np.array([f(xps, yps) for xps, yps in zip(x_mesh, y_mesh)])
def plot_optimizers(ax, iterations, optimizers):
    """Plot the first `iterations` steps of the selected optimizers on the toy surface.

    optimizers: collection of optimizer names (subset of opt_names) to show.
    NOTE(review): passing a bare string does a substring test ('sgd' in 'sgd_decay'
    is True) — pass a list of names to select exactly.
    """
    ax.contour(x_mesh, y_mesh, z, levels=np.logspace(-0.5, 5, 25), norm=LogNorm(), cmap=plt.cm.jet)
    ax.plot(*minima, 'r*', markersize=20)  # mark the global minimum
    for name, path, color in zip(opt_names, all_paths, colors):
        if name in optimizers:
            p = path[:,:iterations]
            # Draw the optimization path as arrows between consecutive iterates
            ax.quiver(p[0,:-1], p[1,:-1], p[0,1:]-p[0,:-1], p[1,1:]-p[1,:-1], scale_units='xy', angles='xy', scale=1, color=color, lw=3)
            ax.plot([], [], color=color, label=name, lw=3, linestyle='-')  # legend proxy artist
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim((x_min, x_max))
    ax.set_ylim((y_min, y_max))
    ax.legend(loc='lower left', prop={'size': 15})
    plt.tight_layout()
2022-03-16 12:43:05.283936: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-03-16 12:43:05.284367: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Metal device set to: Apple M1 Pro
# Toy plot to illustrate nesterov momentum
# TODO: replace with actual gradient computation?
def plot_nesterov(ax, method="Nesterov momentum"):
    """Illustrate (Nesterov) momentum with hand-drawn arrows on the toy surface.

    method: "Momentum" or "Nesterov momentum" — selects which decomposition to draw.
    Note: arrow coordinates are hard-coded for illustration, not computed.
    """
    ax.contour(x_mesh, y_mesh, z, levels=np.logspace(-0.5, 5, 25), norm=LogNorm(), cmap=plt.cm.jet)
    ax.plot(*minima, 'r*', markersize=20)
    # toy example
    ax.quiver(-0.8,-1.13,1,1.33, scale_units='xy', angles='xy', scale=1, color='k', alpha=0.5, lw=3, label="previous update")
    # 0.9 * previous update
    ax.quiver(0.2,0.2,0.9,1.2, scale_units='xy', angles='xy', scale=1, color='g', lw=3, label="momentum step")
    if method == "Momentum":
        ax.quiver(0.2,0.2,0.5,0, scale_units='xy', angles='xy', scale=1, color='r', lw=3, label="gradient step")
        ax.quiver(0.2,0.2,0.9*0.9+0.5,1.2, scale_units='xy', angles='xy', scale=1, color='b', lw=3, label="actual step")
    if method == "Nesterov momentum":
        # Gradient is evaluated at the 'lookahead' point reached by the momentum step
        ax.quiver(1.1,1.4,-0.2,-1, scale_units='xy', angles='xy', scale=1, color='r', lw=3, label="'lookahead' gradient step")
        ax.quiver(0.2,0.2,0.7,0.2, scale_units='xy', angles='xy', scale=1, color='b', lw=3, label="actual step")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(method)
    ax.set_xlim((x_min, x_max))
    ax.set_ylim((-2.5, y_max))
    ax.legend(loc='lower right', prop={'size': 9})
    plt.tight_layout()
Optimizers¶
SGD with learning rate schedules¶
Using a constant learning \(\eta\) rate for weight updates \(\mathbf{w}_{(s+1)} = \mathbf{w}_s-\eta\nabla \mathcal{L}(\mathbf{w}_s)\) is not ideal
Learning rate decay/annealing with decay rate \(k\)
E.g. exponential (\(\eta_{s+1} = \eta_{s} e^{-ks}\)), inverse-time (\(\eta_{s+1} = \frac{\eta_{0}}{1+ks}\)),…
Cyclical learning rates
Change from small to large: hopefully in ‘good’ region long enough before diverging
Warm restarts: aggressive decay + reset to initial learning rate
@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare two optimizers' paths on the toy surface."""
    fig, ax = plt.subplots(figsize=(6,4))
    plot_optimizers(ax,iterations,[optimizer1,optimizer2])
if not interactive:
    fig, axes = plt.subplots(1,2, figsize=(10,3))
    optimizers = ['sgd_decay', 'sgd_cyclic']
    for function, ax in zip(optimizers,axes):
        # Fixed: wrap the name in a list. Passing the bare string made the
        # `name in optimizers` test inside plot_optimizers do a *substring* match,
        # so e.g. plain 'sgd' was also drawn on the 'sgd_decay' panel.
        plot_optimizers(ax,100,[function])
    plt.tight_layout();
    # NOTE(review): 'sgd_cyclic' is currently commented out of opt_names (see the
    # TODO at the top of the notebook), so its panel stays empty until re-enabled.
Momentum¶
Imagine a ball rolling downhill: accumulates momentum, doesn’t exactly follow steepest descent
Reduces oscillation, follows larger (consistent) gradient of the loss surface
Adds a velocity vector \(\mathbf{v}\) with momentum \(\gamma\) (e.g. 0.9, or increase from \(\gamma=0.5\) to \(\gamma=0.99\)) $\(\mathbf{w}_{(s+1)} = \mathbf{w}_{(s)} + \mathbf{v}_{(s)} \qquad \text{with} \qquad \color{blue}{\mathbf{v}_{(s)}} = \color{green}{\gamma \mathbf{v}_{(s-1)}} - \color{red}{\eta \nabla \mathcal{L}(\mathbf{w}_{(s)})}\)$
Nesterov momentum: Look where momentum step would bring you, compute gradient there
Responds faster (and reduces momentum) when the gradient changes $\(\color{blue}{\mathbf{v}_{(s)}} = \color{green}{\gamma \mathbf{v}_{(s-1)}} - \color{red}{\eta \nabla \mathcal{L}(\mathbf{w}_{(s)} + \gamma \mathbf{v}_{(s-1)})}\)$
# Compare plain momentum vs. Nesterov momentum arrow diagrams
fig, axes = plt.subplots(1,2, figsize=(10,2.6))
plot_nesterov(axes[0],method="Momentum")
plot_nesterov(axes[1],method="Nesterov momentum")
Momentum in practice¶
@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare two optimizers' paths on the toy surface."""
    fig, ax = plt.subplots(figsize=(6,4))
    plot_optimizers(ax,iterations,[optimizer1,optimizer2])
# Static panels: sgd vs momentum, and momentum vs nesterov
if not interactive:
    fig, axes = plt.subplots(1,2, figsize=(10,3.5))
    optimizers = [['sgd','momentum'], ['momentum','nesterov']]
    for function, ax in zip(optimizers,axes):
        plot_optimizers(ax,100,function)
    plt.tight_layout();
Adaptive gradients¶
‘Correct’ the learning rate for each \(w_i\) based on specific local conditions (layer depth, fan-in,…)
Adagrad: scale \(\eta\) according to the sum of squared gradients \(G_{i,(s)} = \sum_{t=1}^s \nabla \mathcal{L}(w_{i,(t)})^2\)
Update rule for \(w_i\). Usually \(\epsilon=10^{-7}\) (avoids division by 0), \(\eta=0.001\). $\(w_{i,(s+1)} = w_{i,(s)} - \frac{\eta}{\sqrt{G_{i,(s)}+\epsilon}} \nabla \mathcal{L}(w_{i,(s)})\)$
RMSProp: use moving average of squared gradients \(m_{i,(s)} = \gamma m_{i,(s-1)} + (1-\gamma) \nabla \mathcal{L}(w_{i,(s)})^2\)
Avoids that gradients dwindle to 0 as \(G_{i,(s)}\) grows. Usually \(\gamma=0.9, \eta=0.001\) $\(w_{i,(s+1)} = w_{i,(s)}- \frac{\eta}{\sqrt{m_{i,(s)}+\epsilon}} \nabla \mathcal{L}(w_{i,(s)})\)$
# Static panels comparing adaptive-gradient methods
if not interactive:
    fig, axes = plt.subplots(1,2, figsize=(10,2.6))
    optimizers = [['sgd','adagrad', 'rmsprop'], ['rmsprop','rmsprop_mom']]
    for function, ax in zip(optimizers,axes):
        plot_optimizers(ax,100,function)
    plt.tight_layout();
@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare two optimizers' paths on the toy surface."""
    fig, ax = plt.subplots(figsize=(6,4))
    plot_optimizers(ax,iterations,[optimizer1,optimizer2])
Adam (Adaptive moment estimation)¶
Adam: RMSProp + momentum. Adds moving average for gradients as well (\(\gamma_2\) = momentum):
Adds a bias correction to avoid small initial gradients: \(\hat{m}_{i,(s)} = \frac{m_{i,(s)}}{1-\gamma}\) and \(\hat{g}_{i,(s)} = \frac{g_{i,(s)}}{1-\gamma_2}\) $\(g_{i,(s)} = \gamma_2 g_{i,(s-1)} + (1-\gamma_2) \nabla \mathcal{L}(w_{i,(s)})\)\( \)\(w_{i,(s+1)} = w_{i,(s)}- \frac{\eta}{\sqrt{\hat{m}_{i,(s)}+\epsilon}} \hat{g}_{i,(s)}\)$
Adamax: Idem, but use max() instead of moving average: \(u_{i,(s)} = max(\gamma u_{i,(s-1)}, |\mathcal{L}(w_{i,(s)})|)\) $\(w_{i,(s+1)} = w_{i,(s)}- \frac{\eta}{u_{i,(s)}} \hat{g}_{i,(s)}\)$
# Static panels comparing Adam against SGD and Adamax
# NOTE(review): 'adamax' is commented out of opt_names, so that panel stays empty.
if not interactive:
    fig, axes = plt.subplots(1,2, figsize=(10,2.6))
    optimizers = [['sgd','adam'], ['adam','adamax']]
    for function, ax in zip(optimizers,axes):
        plot_optimizers(ax,100,function)
    plt.tight_layout();
@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare two optimizers' paths on the toy surface."""
    fig, ax = plt.subplots(figsize=(6,4))
    plot_optimizers(ax,iterations,[optimizer1,optimizer2])
SGD Optimizer Zoo¶
RMSProp often works well, but do try alternatives. For even more optimizers, see here.
# Show all optimizers' paths at once
if not interactive:
    fig, ax = plt.subplots(1,1, figsize=(10,5.5))
    plot_optimizers(ax,100,opt_names)
@interact
def compare_optimizers(iterations=(1,100,1)):
    """Interactively step through all optimizers' paths."""
    fig, ax = plt.subplots(figsize=(10,6))
    plot_optimizers(ax,iterations,opt_names)
from tensorflow.keras import models
from tensorflow.keras import layers
from numpy.random import seed
from tensorflow.random import set_seed
import random
import os
# Trying to set all seeds for reproducibility (Python hash seed, random, NumPy, TF)
os.environ['PYTHONHASHSEED']=str(0)
random.seed(0)
seed(0)
set_seed(0)
seed_value= 0  # NOTE(review): unused — the seeds above are hard-coded to 0
Neural networks in practice¶
There are many practical courses on training neural nets. E.g.:
With TensorFlow: https://www.tensorflow.org/resources/learn-ml
With PyTorch: fast.ai course, https://pytorch.org/tutorials/
Here, we’ll use Keras, a general API for building neural networks
Default API for TensorFlow, also has backends for CNTK, Theano
Focus on key design decisions, evaluation, and regularization
Running example: Fashion-MNIST
28x28 pixel images of 10 classes of fashion items
# Download FMNIST data from OpenML. Takes a while the first time.
mnist = oml.datasets.get_dataset(40996)  # OpenML dataset id 40996 = Fashion-MNIST
X, y, _, _ = mnist.get_data(target=mnist.default_target_attribute, dataset_format='array');
X = X.reshape(70000, 28, 28)  # one 28x28 grayscale image per sample
fmnist_classes = {0:"T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal",
                  6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}
# Take some random examples and show them with their class labels
from random import randint
fig, axes = plt.subplots(1, 5, figsize=(10, 5))
for i in range(5):
    # Fixed: randint is inclusive on both ends, so randint(0, 70000) could
    # return 70000 and index out of bounds; valid indices are 0..len(X)-1.
    n = randint(0, len(X) - 1)
    axes[i].imshow(X[n], cmap=plt.cm.gray_r)
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    axes[i].set_xlabel("{}".format(fmnist_classes[y[n]]))
plt.show();
Building the network¶
We first build a simple sequential model (no branches)
Input layer (‘input_shape’): a flat vector of 28*28=784 nodes
We’ll see how to properly deal with images later
Two dense hidden layers: 512 nodes each, ReLU activation
Glorot weight initialization is applied by default
Output layer: 10 nodes (for 10 classes) and softmax activation
# Sequential model: 784 inputs -> 512 (ReLU) -> 512 (ReLU) -> 10 (softmax)
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
from tensorflow.keras import initializers
# Same architecture rebuilt (He-normal initialization set explicitly)
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
Model summary¶
Lots of parameters (weights and biases) to learn!
hidden layer 1 : (28 * 28 + 1) * 512 = 401920
hidden layer 2 : (512 + 1) * 512 = 262656
output layer: (512 + 1) * 10 = 5130
# Print layer output shapes and parameter counts (shown twice in the rendered notebook)
network.summary()
network.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 512) 401920
dense_1 (Dense) (None, 512) 262656
dense_2 (Dense) (None, 10) 5130
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
_________________________________________________________________
Choosing loss, optimizer, metrics¶
Loss function
Cross-entropy (log loss) for multi-class classification (\(y_{true}\) is one-hot encoded)
Use binary crossentropy for binary problems (single output node)
Use sparse categorical crossentropy if \(y_{true}\) is label-encoded (1,2,3,…)
Optimizer
Any of the optimizers we discussed before. RMSprop usually works well.
Metrics
To monitor performance during training and testing, e.g. accuracy
# Imports moved above their first use (they originally came after the
# "Detailed" compile call, which would raise a NameError if executed in order)
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import Accuracy
# Shorthand
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# Detailed
network.compile(loss=CategoricalCrossentropy(label_smoothing=0.01),
                optimizer=RMSprop(learning_rate=0.001, momentum=0.0),  # fixed: missing comma was a SyntaxError
                # NOTE(review): metrics.Accuracy compares raw outputs for equality;
                # for softmax outputs, CategoricalAccuracy (what the string
                # 'accuracy' resolves to) is presumably intended — confirm.
                metrics=[Accuracy()])
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
Preprocessing: Normalization, Reshaping, Encoding¶
Always normalize (standardize or min-max) the inputs. Mean should be close to 0.
Avoid that some inputs overpower others
Speed up convergence
Gradients of activation functions \(\frac{\partial a_{h}}{\partial z_{h}}\) are (near) 0 for large inputs
If some gradients become much larger than others, SGD will start zig-zagging
Reshape the data to fit the shape of the input layer, e.g. (n, 28*28) or (n, 28,28)
Tensor with instances in first dimension, rest must match the input layer
In multi-class classification, every class is an output node, so one-hot-encode the labels
e.g. class ‘4’ becomes [0,0,0,0,1,0,0,0,0,0]
# Slide version of the preprocessing (the executed version below uses Xf_train/Xf_test)
X = X.astype('float32') / 255  # min-max scale pixel values to [0, 1]
X = X.reshape((70000, 28 * 28))  # fixed: X holds all 70000 FMNIST samples, not 60000
# NOTE(review): to_categorical is imported further down — move that import up if this cell should run
y = to_categorical(y)
from sklearn.model_selection import train_test_split
# Split the 70000 samples into 60000 train / 10000 test
Xf_train, Xf_test, yf_train, yf_test = train_test_split(X, y, train_size=60000, shuffle=True, random_state=0)
Xf_train = Xf_train.reshape((60000, 28 * 28))  # flatten images to 784-dim vectors
Xf_test = Xf_test.reshape((10000, 28 * 28))
# TODO: check if standardization works better
Xf_train = Xf_train.astype('float32') / 255  # min-max scale pixel values to [0, 1]
Xf_test = Xf_test.astype('float32') / 255
from tensorflow.keras.utils import to_categorical
yf_train = to_categorical(yf_train)  # one-hot encode the class labels
yf_test = to_categorical(yf_test)
Choosing training hyperparameters¶
Number of epochs: enough to allow convergence
Too much: model starts overfitting (or just wastes time)
Batch size: small batches (e.g. 32, 64,… samples) often preferred
‘Noisy’ training data makes overfitting less likely
Larger batches generalize less well (‘generalization gap’)
Requires less memory (especially in GPUs)
Large batches do speed up training, may converge in fewer epochs
Batch size interacts with learning rate
Instead of shrinking the learning rate you can increase batch size
# Generic form (slide example): 3 epochs, mini-batches of 32 samples
history = network.fit(X_train, y_train, epochs=3, batch_size=32);
# Actual run on the preprocessed split (Xf_* — presumably Fashion-MNIST; confirm
# against the data-loading cell). fit() returns a History object with per-epoch metrics.
history = network.fit(Xf_train, yf_train, epochs=3, batch_size=32);
2022-03-16 12:43:34.490590: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/3
2022-03-16 12:43:34.772914: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
1/1875 [..............................] - ETA: 15:39 - loss: 2.4362 - accuracy: 0.0312
9/1875 [..............................] - ETA: 13s - loss: 2.0937 - accuracy: 0.3472
17/1875 [..............................] - ETA: 12s - loss: 1.5483 - accuracy: 0.5018
25/1875 [..............................] - ETA: 12s - loss: 1.3969 - accuracy: 0.5375
34/1875 [..............................] - ETA: 11s - loss: 1.2232 - accuracy: 0.5855
42/1875 [..............................] - ETA: 11s - loss: 1.1373 - accuracy: 0.6124
51/1875 [..............................] - ETA: 11s - loss: 1.0716 - accuracy: 0.6330
60/1875 [..............................] - ETA: 11s - loss: 1.0320 - accuracy: 0.6422
68/1875 [>.............................] - ETA: 11s - loss: 0.9942 - accuracy: 0.6558
77/1875 [>.............................] - ETA: 11s - loss: 0.9666 - accuracy: 0.6664
85/1875 [>.............................] - ETA: 11s - loss: 0.9478 - accuracy: 0.6724
94/1875 [>.............................] - ETA: 11s - loss: 0.9263 - accuracy: 0.6785
102/1875 [>.............................] - ETA: 11s - loss: 0.8979 - accuracy: 0.6857
111/1875 [>.............................] - ETA: 11s - loss: 0.8843 - accuracy: 0.6914
119/1875 [>.............................] - ETA: 11s - loss: 0.8619 - accuracy: 0.6993
127/1875 [=>............................] - ETA: 11s - loss: 0.8515 - accuracy: 0.7030
135/1875 [=>............................] - ETA: 10s - loss: 0.8332 - accuracy: 0.7088
143/1875 [=>............................] - ETA: 10s - loss: 0.8263 - accuracy: 0.7120
151/1875 [=>............................] - ETA: 11s - loss: 0.8147 - accuracy: 0.7146
159/1875 [=>............................] - ETA: 10s - loss: 0.8036 - accuracy: 0.7188
167/1875 [=>............................] - ETA: 11s - loss: 0.7937 - accuracy: 0.7214
175/1875 [=>............................] - ETA: 11s - loss: 0.7858 - accuracy: 0.7236
183/1875 [=>............................] - ETA: 10s - loss: 0.7803 - accuracy: 0.7249
191/1875 [==>...........................] - ETA: 10s - loss: 0.7702 - accuracy: 0.7279
199/1875 [==>...........................] - ETA: 10s - loss: 0.7650 - accuracy: 0.7283
207/1875 [==>...........................] - ETA: 10s - loss: 0.7608 - accuracy: 0.7296
215/1875 [==>...........................] - ETA: 10s - loss: 0.7579 - accuracy: 0.7301
223/1875 [==>...........................] - ETA: 10s - loss: 0.7550 - accuracy: 0.7298
231/1875 [==>...........................] - ETA: 10s - loss: 0.7467 - accuracy: 0.7321
240/1875 [==>...........................] - ETA: 10s - loss: 0.7414 - accuracy: 0.7337
248/1875 [==>...........................] - ETA: 10s - loss: 0.7377 - accuracy: 0.7341
256/1875 [===>..........................] - ETA: 10s - loss: 0.7344 - accuracy: 0.7355
264/1875 [===>..........................] - ETA: 10s - loss: 0.7305 - accuracy: 0.7377
273/1875 [===>..........................] - ETA: 10s - loss: 0.7260 - accuracy: 0.7378
281/1875 [===>..........................] - ETA: 10s - loss: 0.7195 - accuracy: 0.7409
289/1875 [===>..........................] - ETA: 10s - loss: 0.7128 - accuracy: 0.7426
297/1875 [===>..........................] - ETA: 10s - loss: 0.7095 - accuracy: 0.7440
305/1875 [===>..........................] - ETA: 10s - loss: 0.7024 - accuracy: 0.7468
314/1875 [====>.........................] - ETA: 10s - loss: 0.6995 - accuracy: 0.7483
323/1875 [====>.........................] - ETA: 10s - loss: 0.6932 - accuracy: 0.7501
331/1875 [====>.........................] - ETA: 10s - loss: 0.6903 - accuracy: 0.7517
339/1875 [====>.........................] - ETA: 9s - loss: 0.6874 - accuracy: 0.7526
347/1875 [====>.........................] - ETA: 9s - loss: 0.6812 - accuracy: 0.7551
356/1875 [====>.........................] - ETA: 9s - loss: 0.6818 - accuracy: 0.7558
364/1875 [====>.........................] - ETA: 9s - loss: 0.6767 - accuracy: 0.7582
372/1875 [====>.........................] - ETA: 9s - loss: 0.6722 - accuracy: 0.7594
380/1875 [=====>........................] - ETA: 9s - loss: 0.6695 - accuracy: 0.7596
388/1875 [=====>........................] - ETA: 9s - loss: 0.6675 - accuracy: 0.7608
396/1875 [=====>........................] - ETA: 9s - loss: 0.6649 - accuracy: 0.7614
404/1875 [=====>........................] - ETA: 9s - loss: 0.6628 - accuracy: 0.7617
412/1875 [=====>........................] - ETA: 9s - loss: 0.6593 - accuracy: 0.7628
420/1875 [=====>........................] - ETA: 9s - loss: 0.6584 - accuracy: 0.7632
428/1875 [=====>........................] - ETA: 9s - loss: 0.6567 - accuracy: 0.7638
436/1875 [=====>........................] - ETA: 9s - loss: 0.6537 - accuracy: 0.7645
444/1875 [======>.......................] - ETA: 9s - loss: 0.6517 - accuracy: 0.7653
452/1875 [======>.......................] - ETA: 9s - loss: 0.6506 - accuracy: 0.7658
458/1875 [======>.......................] - ETA: 9s - loss: 0.6476 - accuracy: 0.7669
466/1875 [======>.......................] - ETA: 9s - loss: 0.6442 - accuracy: 0.7679
474/1875 [======>.......................] - ETA: 9s - loss: 0.6428 - accuracy: 0.7683
482/1875 [======>.......................] - ETA: 9s - loss: 0.6409 - accuracy: 0.7691
490/1875 [======>.......................] - ETA: 9s - loss: 0.6368 - accuracy: 0.7706
498/1875 [======>.......................] - ETA: 8s - loss: 0.6359 - accuracy: 0.7708
506/1875 [=======>......................] - ETA: 8s - loss: 0.6340 - accuracy: 0.7713
515/1875 [=======>......................] - ETA: 8s - loss: 0.6351 - accuracy: 0.7706
524/1875 [=======>......................] - ETA: 8s - loss: 0.6330 - accuracy: 0.7714
532/1875 [=======>......................] - ETA: 8s - loss: 0.6308 - accuracy: 0.7723
541/1875 [=======>......................] - ETA: 8s - loss: 0.6291 - accuracy: 0.7729
550/1875 [=======>......................] - ETA: 8s - loss: 0.6272 - accuracy: 0.7737
559/1875 [=======>......................] - ETA: 8s - loss: 0.6248 - accuracy: 0.7748
562/1875 [=======>......................] - ETA: 8s - loss: 0.6250 - accuracy: 0.7748
570/1875 [========>.....................] - ETA: 8s - loss: 0.6232 - accuracy: 0.7751
578/1875 [========>.....................] - ETA: 8s - loss: 0.6221 - accuracy: 0.7756
584/1875 [========>.....................] - ETA: 8s - loss: 0.6207 - accuracy: 0.7762
592/1875 [========>.....................] - ETA: 8s - loss: 0.6212 - accuracy: 0.7761
600/1875 [========>.....................] - ETA: 8s - loss: 0.6192 - accuracy: 0.7768
608/1875 [========>.....................] - ETA: 8s - loss: 0.6177 - accuracy: 0.7774
616/1875 [========>.....................] - ETA: 8s - loss: 0.6161 - accuracy: 0.7777
624/1875 [========>.....................] - ETA: 8s - loss: 0.6159 - accuracy: 0.7781
632/1875 [=========>....................] - ETA: 8s - loss: 0.6132 - accuracy: 0.7791
640/1875 [=========>....................] - ETA: 8s - loss: 0.6123 - accuracy: 0.7797
648/1875 [=========>....................] - ETA: 8s - loss: 0.6109 - accuracy: 0.7800
656/1875 [=========>....................] - ETA: 8s - loss: 0.6085 - accuracy: 0.7808
664/1875 [=========>....................] - ETA: 7s - loss: 0.6072 - accuracy: 0.7810
672/1875 [=========>....................] - ETA: 7s - loss: 0.6050 - accuracy: 0.7814
680/1875 [=========>....................] - ETA: 7s - loss: 0.6043 - accuracy: 0.7816
688/1875 [==========>...................] - ETA: 7s - loss: 0.6027 - accuracy: 0.7822
696/1875 [==========>...................] - ETA: 7s - loss: 0.6013 - accuracy: 0.7826
704/1875 [==========>...................] - ETA: 7s - loss: 0.6003 - accuracy: 0.7828
712/1875 [==========>...................] - ETA: 7s - loss: 0.5991 - accuracy: 0.7835
720/1875 [==========>...................] - ETA: 7s - loss: 0.5981 - accuracy: 0.7841
728/1875 [==========>...................] - ETA: 7s - loss: 0.5966 - accuracy: 0.7845
736/1875 [==========>...................] - ETA: 7s - loss: 0.5962 - accuracy: 0.7847
744/1875 [==========>...................] - ETA: 7s - loss: 0.5944 - accuracy: 0.7855
752/1875 [===========>..................] - ETA: 7s - loss: 0.5929 - accuracy: 0.7860
760/1875 [===========>..................] - ETA: 7s - loss: 0.5916 - accuracy: 0.7866
768/1875 [===========>..................] - ETA: 7s - loss: 0.5912 - accuracy: 0.7867
776/1875 [===========>..................] - ETA: 7s - loss: 0.5907 - accuracy: 0.7868
784/1875 [===========>..................] - ETA: 7s - loss: 0.5893 - accuracy: 0.7873
792/1875 [===========>..................] - ETA: 7s - loss: 0.5878 - accuracy: 0.7878
800/1875 [===========>..................] - ETA: 7s - loss: 0.5867 - accuracy: 0.7886
808/1875 [===========>..................] - ETA: 7s - loss: 0.5860 - accuracy: 0.7889
816/1875 [============>.................] - ETA: 6s - loss: 0.5847 - accuracy: 0.7892
824/1875 [============>.................] - ETA: 6s - loss: 0.5836 - accuracy: 0.7897
832/1875 [============>.................] - ETA: 6s - loss: 0.5820 - accuracy: 0.7903
840/1875 [============>.................] - ETA: 6s - loss: 0.5812 - accuracy: 0.7906
848/1875 [============>.................] - ETA: 6s - loss: 0.5810 - accuracy: 0.7908
856/1875 [============>.................] - ETA: 6s - loss: 0.5794 - accuracy: 0.7913
864/1875 [============>.................] - ETA: 6s - loss: 0.5786 - accuracy: 0.7914
872/1875 [============>.................] - ETA: 6s - loss: 0.5781 - accuracy: 0.7919
880/1875 [=============>................] - ETA: 6s - loss: 0.5777 - accuracy: 0.7919
888/1875 [=============>................] - ETA: 6s - loss: 0.5764 - accuracy: 0.7925
896/1875 [=============>................] - ETA: 6s - loss: 0.5758 - accuracy: 0.7928
904/1875 [=============>................] - ETA: 6s - loss: 0.5761 - accuracy: 0.7928
912/1875 [=============>................] - ETA: 6s - loss: 0.5763 - accuracy: 0.7926
920/1875 [=============>................] - ETA: 6s - loss: 0.5749 - accuracy: 0.7931
928/1875 [=============>................] - ETA: 6s - loss: 0.5738 - accuracy: 0.7932
936/1875 [=============>................] - ETA: 6s - loss: 0.5731 - accuracy: 0.7933
944/1875 [==============>...............] - ETA: 6s - loss: 0.5722 - accuracy: 0.7938
952/1875 [==============>...............] - ETA: 6s - loss: 0.5714 - accuracy: 0.7941
960/1875 [==============>...............] - ETA: 6s - loss: 0.5705 - accuracy: 0.7944
968/1875 [==============>...............] - ETA: 5s - loss: 0.5703 - accuracy: 0.7944
976/1875 [==============>...............] - ETA: 5s - loss: 0.5691 - accuracy: 0.7949
984/1875 [==============>...............] - ETA: 5s - loss: 0.5683 - accuracy: 0.7951
992/1875 [==============>...............] - ETA: 5s - loss: 0.5664 - accuracy: 0.7957
1000/1875 [===============>..............] - ETA: 5s - loss: 0.5655 - accuracy: 0.7960
1008/1875 [===============>..............] - ETA: 5s - loss: 0.5650 - accuracy: 0.7962
1016/1875 [===============>..............] - ETA: 5s - loss: 0.5641 - accuracy: 0.7965
1024/1875 [===============>..............] - ETA: 5s - loss: 0.5625 - accuracy: 0.7971
1032/1875 [===============>..............] - ETA: 5s - loss: 0.5614 - accuracy: 0.7975
1040/1875 [===============>..............] - ETA: 5s - loss: 0.5605 - accuracy: 0.7981
1048/1875 [===============>..............] - ETA: 5s - loss: 0.5593 - accuracy: 0.7985
1057/1875 [===============>..............] - ETA: 5s - loss: 0.5592 - accuracy: 0.7990
1065/1875 [================>.............] - ETA: 5s - loss: 0.5586 - accuracy: 0.7991
1073/1875 [================>.............] - ETA: 5s - loss: 0.5576 - accuracy: 0.7994
1081/1875 [================>.............] - ETA: 5s - loss: 0.5569 - accuracy: 0.7996
1089/1875 [================>.............] - ETA: 5s - loss: 0.5563 - accuracy: 0.7998
1098/1875 [================>.............] - ETA: 5s - loss: 0.5554 - accuracy: 0.7998
1106/1875 [================>.............] - ETA: 5s - loss: 0.5543 - accuracy: 0.8001
1114/1875 [================>.............] - ETA: 5s - loss: 0.5539 - accuracy: 0.8003
1122/1875 [================>.............] - ETA: 4s - loss: 0.5528 - accuracy: 0.8007
1130/1875 [=================>............] - ETA: 4s - loss: 0.5527 - accuracy: 0.8007
1138/1875 [=================>............] - ETA: 4s - loss: 0.5520 - accuracy: 0.8009
1146/1875 [=================>............] - ETA: 4s - loss: 0.5520 - accuracy: 0.8011
1154/1875 [=================>............] - ETA: 4s - loss: 0.5514 - accuracy: 0.8012
1162/1875 [=================>............] - ETA: 4s - loss: 0.5505 - accuracy: 0.8014
1170/1875 [=================>............] - ETA: 4s - loss: 0.5501 - accuracy: 0.8015
1178/1875 [=================>............] - ETA: 4s - loss: 0.5500 - accuracy: 0.8016
1186/1875 [=================>............] - ETA: 4s - loss: 0.5483 - accuracy: 0.8023
1195/1875 [==================>...........] - ETA: 4s - loss: 0.5473 - accuracy: 0.8026
1204/1875 [==================>...........] - ETA: 4s - loss: 0.5468 - accuracy: 0.8028
1213/1875 [==================>...........] - ETA: 4s - loss: 0.5467 - accuracy: 0.8028
1222/1875 [==================>...........] - ETA: 4s - loss: 0.5456 - accuracy: 0.8033
1231/1875 [==================>...........] - ETA: 4s - loss: 0.5450 - accuracy: 0.8034
1240/1875 [==================>...........] - ETA: 4s - loss: 0.5439 - accuracy: 0.8039
1249/1875 [==================>...........] - ETA: 4s - loss: 0.5434 - accuracy: 0.8040
1258/1875 [===================>..........] - ETA: 4s - loss: 0.5429 - accuracy: 0.8042
1267/1875 [===================>..........] - ETA: 3s - loss: 0.5423 - accuracy: 0.8042
1276/1875 [===================>..........] - ETA: 3s - loss: 0.5415 - accuracy: 0.8044
1285/1875 [===================>..........] - ETA: 3s - loss: 0.5409 - accuracy: 0.8046
1293/1875 [===================>..........] - ETA: 3s - loss: 0.5416 - accuracy: 0.8046
1301/1875 [===================>..........] - ETA: 3s - loss: 0.5408 - accuracy: 0.8049
1309/1875 [===================>..........] - ETA: 3s - loss: 0.5394 - accuracy: 0.8053
1318/1875 [====================>.........] - ETA: 3s - loss: 0.5392 - accuracy: 0.8056
1327/1875 [====================>.........] - ETA: 3s - loss: 0.5382 - accuracy: 0.8060
1336/1875 [====================>.........] - ETA: 3s - loss: 0.5371 - accuracy: 0.8063
1345/1875 [====================>.........] - ETA: 3s - loss: 0.5368 - accuracy: 0.8064
1354/1875 [====================>.........] - ETA: 3s - loss: 0.5358 - accuracy: 0.8068
1363/1875 [====================>.........] - ETA: 3s - loss: 0.5358 - accuracy: 0.8070
1371/1875 [====================>.........] - ETA: 3s - loss: 0.5355 - accuracy: 0.8072
1379/1875 [=====================>........] - ETA: 3s - loss: 0.5349 - accuracy: 0.8075
1387/1875 [=====================>........] - ETA: 3s - loss: 0.5344 - accuracy: 0.8078
1395/1875 [=====================>........] - ETA: 3s - loss: 0.5339 - accuracy: 0.8082
1403/1875 [=====================>........] - ETA: 3s - loss: 0.5338 - accuracy: 0.8083
1411/1875 [=====================>........] - ETA: 3s - loss: 0.5342 - accuracy: 0.8080
1419/1875 [=====================>........] - ETA: 2s - loss: 0.5337 - accuracy: 0.8083
1427/1875 [=====================>........] - ETA: 2s - loss: 0.5333 - accuracy: 0.8083
1435/1875 [=====================>........] - ETA: 2s - loss: 0.5333 - accuracy: 0.8084
1443/1875 [======================>.......] - ETA: 2s - loss: 0.5329 - accuracy: 0.8084
1451/1875 [======================>.......] - ETA: 2s - loss: 0.5326 - accuracy: 0.8084
1459/1875 [======================>.......] - ETA: 2s - loss: 0.5323 - accuracy: 0.8085
1467/1875 [======================>.......] - ETA: 2s - loss: 0.5326 - accuracy: 0.8084
1475/1875 [======================>.......] - ETA: 2s - loss: 0.5319 - accuracy: 0.8086
1483/1875 [======================>.......] - ETA: 2s - loss: 0.5312 - accuracy: 0.8089
1491/1875 [======================>.......] - ETA: 2s - loss: 0.5307 - accuracy: 0.8091
1499/1875 [======================>.......] - ETA: 2s - loss: 0.5298 - accuracy: 0.8095
1507/1875 [=======================>......] - ETA: 2s - loss: 0.5293 - accuracy: 0.8097
1514/1875 [=======================>......] - ETA: 2s - loss: 0.5285 - accuracy: 0.8099
1522/1875 [=======================>......] - ETA: 2s - loss: 0.5278 - accuracy: 0.8102
1530/1875 [=======================>......] - ETA: 2s - loss: 0.5278 - accuracy: 0.8103
1538/1875 [=======================>......] - ETA: 2s - loss: 0.5276 - accuracy: 0.8103
1546/1875 [=======================>......] - ETA: 2s - loss: 0.5269 - accuracy: 0.8105
1554/1875 [=======================>......] - ETA: 2s - loss: 0.5262 - accuracy: 0.8108
1562/1875 [=======================>......] - ETA: 2s - loss: 0.5259 - accuracy: 0.8109
1570/1875 [========================>.....] - ETA: 1s - loss: 0.5260 - accuracy: 0.8109
1579/1875 [========================>.....] - ETA: 1s - loss: 0.5250 - accuracy: 0.8112
1588/1875 [========================>.....] - ETA: 1s - loss: 0.5242 - accuracy: 0.8114
1597/1875 [========================>.....] - ETA: 1s - loss: 0.5241 - accuracy: 0.8115
1606/1875 [========================>.....] - ETA: 1s - loss: 0.5231 - accuracy: 0.8118
1615/1875 [========================>.....] - ETA: 1s - loss: 0.5227 - accuracy: 0.8120
1624/1875 [========================>.....] - ETA: 1s - loss: 0.5219 - accuracy: 0.8123
1633/1875 [=========================>....] - ETA: 1s - loss: 0.5216 - accuracy: 0.8126
1642/1875 [=========================>....] - ETA: 1s - loss: 0.5214 - accuracy: 0.8127
1651/1875 [=========================>....] - ETA: 1s - loss: 0.5212 - accuracy: 0.8128
1660/1875 [=========================>....] - ETA: 1s - loss: 0.5204 - accuracy: 0.8130
1669/1875 [=========================>....] - ETA: 1s - loss: 0.5206 - accuracy: 0.8130
1677/1875 [=========================>....] - ETA: 1s - loss: 0.5199 - accuracy: 0.8133
1685/1875 [=========================>....] - ETA: 1s - loss: 0.5190 - accuracy: 0.8136
1694/1875 [==========================>...] - ETA: 1s - loss: 0.5188 - accuracy: 0.8138
1703/1875 [==========================>...] - ETA: 1s - loss: 0.5182 - accuracy: 0.8142
1711/1875 [==========================>...] - ETA: 1s - loss: 0.5174 - accuracy: 0.8144
1720/1875 [==========================>...] - ETA: 1s - loss: 0.5171 - accuracy: 0.8146
1729/1875 [==========================>...] - ETA: 0s - loss: 0.5169 - accuracy: 0.8146
1738/1875 [==========================>...] - ETA: 0s - loss: 0.5165 - accuracy: 0.8146
1746/1875 [==========================>...] - ETA: 0s - loss: 0.5162 - accuracy: 0.8148
1755/1875 [===========================>..] - ETA: 0s - loss: 0.5156 - accuracy: 0.8150
1764/1875 [===========================>..] - ETA: 0s - loss: 0.5149 - accuracy: 0.8152
1773/1875 [===========================>..] - ETA: 0s - loss: 0.5149 - accuracy: 0.8153
1781/1875 [===========================>..] - ETA: 0s - loss: 0.5148 - accuracy: 0.8154
1790/1875 [===========================>..] - ETA: 0s - loss: 0.5146 - accuracy: 0.8155
1799/1875 [===========================>..] - ETA: 0s - loss: 0.5144 - accuracy: 0.8155
1808/1875 [===========================>..] - ETA: 0s - loss: 0.5137 - accuracy: 0.8157
1817/1875 [============================>.] - ETA: 0s - loss: 0.5140 - accuracy: 0.8159
1825/1875 [============================>.] - ETA: 0s - loss: 0.5130 - accuracy: 0.8162
1834/1875 [============================>.] - ETA: 0s - loss: 0.5130 - accuracy: 0.8162
1843/1875 [============================>.] - ETA: 0s - loss: 0.5123 - accuracy: 0.8166
1852/1875 [============================>.] - ETA: 0s - loss: 0.5119 - accuracy: 0.8168
1861/1875 [============================>.] - ETA: 0s - loss: 0.5114 - accuracy: 0.8170
1870/1875 [============================>.] - ETA: 0s - loss: 0.5106 - accuracy: 0.8172
1875/1875 [==============================] - 13s 6ms/step - loss: 0.5103 - accuracy: 0.8174
Epoch 2/3
1/1875 [..............................] - ETA: 12s - loss: 0.1696 - accuracy: 0.9375
10/1875 [..............................] - ETA: 11s - loss: 0.4546 - accuracy: 0.8250
18/1875 [..............................] - ETA: 11s - loss: 0.4349 - accuracy: 0.8351
27/1875 [..............................] - ETA: 11s - loss: 0.4353 - accuracy: 0.8391
35/1875 [..............................] - ETA: 11s - loss: 0.4150 - accuracy: 0.8464
43/1875 [..............................] - ETA: 11s - loss: 0.4317 - accuracy: 0.8452
51/1875 [..............................] - ETA: 11s - loss: 0.4179 - accuracy: 0.8517
59/1875 [..............................] - ETA: 11s - loss: 0.4145 - accuracy: 0.8549
67/1875 [>.............................] - ETA: 11s - loss: 0.4155 - accuracy: 0.8521
76/1875 [>.............................] - ETA: 11s - loss: 0.4053 - accuracy: 0.8557
85/1875 [>.............................] - ETA: 11s - loss: 0.4014 - accuracy: 0.8577
94/1875 [>.............................] - ETA: 11s - loss: 0.3936 - accuracy: 0.8614
103/1875 [>.............................] - ETA: 11s - loss: 0.3949 - accuracy: 0.8623
111/1875 [>.............................] - ETA: 11s - loss: 0.4062 - accuracy: 0.8598
119/1875 [>.............................] - ETA: 11s - loss: 0.4127 - accuracy: 0.8571
127/1875 [=>............................] - ETA: 11s - loss: 0.4083 - accuracy: 0.8590
135/1875 [=>............................] - ETA: 10s - loss: 0.4162 - accuracy: 0.8567
143/1875 [=>............................] - ETA: 10s - loss: 0.4199 - accuracy: 0.8566
151/1875 [=>............................] - ETA: 10s - loss: 0.4227 - accuracy: 0.8562
159/1875 [=>............................] - ETA: 10s - loss: 0.4239 - accuracy: 0.8561
167/1875 [=>............................] - ETA: 10s - loss: 0.4230 - accuracy: 0.8567
175/1875 [=>............................] - ETA: 10s - loss: 0.4219 - accuracy: 0.8561
183/1875 [=>............................] - ETA: 10s - loss: 0.4241 - accuracy: 0.8560
191/1875 [==>...........................] - ETA: 10s - loss: 0.4213 - accuracy: 0.8573
199/1875 [==>...........................] - ETA: 10s - loss: 0.4253 - accuracy: 0.8568
207/1875 [==>...........................] - ETA: 10s - loss: 0.4272 - accuracy: 0.8557
215/1875 [==>...........................] - ETA: 10s - loss: 0.4310 - accuracy: 0.8542
223/1875 [==>...........................] - ETA: 10s - loss: 0.4307 - accuracy: 0.8538
231/1875 [==>...........................] - ETA: 10s - loss: 0.4271 - accuracy: 0.8546
239/1875 [==>...........................] - ETA: 10s - loss: 0.4253 - accuracy: 0.8550
247/1875 [==>...........................] - ETA: 10s - loss: 0.4240 - accuracy: 0.8551
255/1875 [===>..........................] - ETA: 10s - loss: 0.4230 - accuracy: 0.8554
263/1875 [===>..........................] - ETA: 10s - loss: 0.4243 - accuracy: 0.8538
271/1875 [===>..........................] - ETA: 10s - loss: 0.4256 - accuracy: 0.8533
279/1875 [===>..........................] - ETA: 10s - loss: 0.4300 - accuracy: 0.8526
287/1875 [===>..........................] - ETA: 10s - loss: 0.4270 - accuracy: 0.8527
295/1875 [===>..........................] - ETA: 10s - loss: 0.4260 - accuracy: 0.8529
303/1875 [===>..........................] - ETA: 10s - loss: 0.4290 - accuracy: 0.8519
311/1875 [===>..........................] - ETA: 10s - loss: 0.4280 - accuracy: 0.8520
319/1875 [====>.........................] - ETA: 10s - loss: 0.4293 - accuracy: 0.8523
327/1875 [====>.........................] - ETA: 10s - loss: 0.4309 - accuracy: 0.8522
335/1875 [====>.........................] - ETA: 10s - loss: 0.4284 - accuracy: 0.8531
343/1875 [====>.........................] - ETA: 9s - loss: 0.4282 - accuracy: 0.8530
351/1875 [====>.........................] - ETA: 9s - loss: 0.4296 - accuracy: 0.8520
359/1875 [====>.........................] - ETA: 9s - loss: 0.4284 - accuracy: 0.8525
367/1875 [====>.........................] - ETA: 9s - loss: 0.4288 - accuracy: 0.8525
375/1875 [=====>........................] - ETA: 9s - loss: 0.4307 - accuracy: 0.8525
383/1875 [=====>........................] - ETA: 9s - loss: 0.4295 - accuracy: 0.8532
391/1875 [=====>........................] - ETA: 9s - loss: 0.4275 - accuracy: 0.8545
399/1875 [=====>........................] - ETA: 9s - loss: 0.4271 - accuracy: 0.8541
407/1875 [=====>........................] - ETA: 9s - loss: 0.4277 - accuracy: 0.8537
415/1875 [=====>........................] - ETA: 9s - loss: 0.4276 - accuracy: 0.8542
423/1875 [=====>........................] - ETA: 9s - loss: 0.4282 - accuracy: 0.8542
431/1875 [=====>........................] - ETA: 9s - loss: 0.4287 - accuracy: 0.8538
439/1875 [======>.......................] - ETA: 9s - loss: 0.4286 - accuracy: 0.8541
447/1875 [======>.......................] - ETA: 9s - loss: 0.4274 - accuracy: 0.8543
455/1875 [======>.......................] - ETA: 9s - loss: 0.4265 - accuracy: 0.8546
463/1875 [======>.......................] - ETA: 9s - loss: 0.4264 - accuracy: 0.8544
471/1875 [======>.......................] - ETA: 9s - loss: 0.4273 - accuracy: 0.8539
479/1875 [======>.......................] - ETA: 9s - loss: 0.4261 - accuracy: 0.8538
487/1875 [======>.......................] - ETA: 9s - loss: 0.4244 - accuracy: 0.8545
495/1875 [======>.......................] - ETA: 8s - loss: 0.4253 - accuracy: 0.8544
503/1875 [=======>......................] - ETA: 8s - loss: 0.4239 - accuracy: 0.8548
511/1875 [=======>......................] - ETA: 8s - loss: 0.4231 - accuracy: 0.8552
519/1875 [=======>......................] - ETA: 8s - loss: 0.4221 - accuracy: 0.8554
527/1875 [=======>......................] - ETA: 8s - loss: 0.4226 - accuracy: 0.8556
535/1875 [=======>......................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8554
543/1875 [=======>......................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8554
551/1875 [=======>......................] - ETA: 8s - loss: 0.4232 - accuracy: 0.8549
559/1875 [=======>......................] - ETA: 8s - loss: 0.4223 - accuracy: 0.8550
567/1875 [========>.....................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8549
575/1875 [========>.....................] - ETA: 8s - loss: 0.4227 - accuracy: 0.8548
583/1875 [========>.....................] - ETA: 8s - loss: 0.4230 - accuracy: 0.8548
591/1875 [========>.....................] - ETA: 8s - loss: 0.4235 - accuracy: 0.8544
599/1875 [========>.....................] - ETA: 8s - loss: 0.4226 - accuracy: 0.8548
607/1875 [========>.....................] - ETA: 8s - loss: 0.4215 - accuracy: 0.8554
615/1875 [========>.....................] - ETA: 8s - loss: 0.4209 - accuracy: 0.8553
623/1875 [========>.....................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8549
631/1875 [=========>....................] - ETA: 8s - loss: 0.4215 - accuracy: 0.8551
639/1875 [=========>....................] - ETA: 8s - loss: 0.4224 - accuracy: 0.8553
647/1875 [=========>....................] - ETA: 7s - loss: 0.4225 - accuracy: 0.8554
655/1875 [=========>....................] - ETA: 7s - loss: 0.4227 - accuracy: 0.8553
663/1875 [=========>....................] - ETA: 7s - loss: 0.4209 - accuracy: 0.8560
671/1875 [=========>....................] - ETA: 7s - loss: 0.4210 - accuracy: 0.8556
679/1875 [=========>....................] - ETA: 7s - loss: 0.4196 - accuracy: 0.8557
687/1875 [=========>....................] - ETA: 7s - loss: 0.4194 - accuracy: 0.8556
695/1875 [==========>...................] - ETA: 7s - loss: 0.4211 - accuracy: 0.8555
703/1875 [==========>...................] - ETA: 7s - loss: 0.4206 - accuracy: 0.8554
711/1875 [==========>...................] - ETA: 7s - loss: 0.4209 - accuracy: 0.8551
719/1875 [==========>...................] - ETA: 7s - loss: 0.4201 - accuracy: 0.8554
727/1875 [==========>...................] - ETA: 7s - loss: 0.4186 - accuracy: 0.8558
735/1875 [==========>...................] - ETA: 7s - loss: 0.4186 - accuracy: 0.8558
743/1875 [==========>...................] - ETA: 7s - loss: 0.4181 - accuracy: 0.8560
751/1875 [===========>..................] - ETA: 7s - loss: 0.4185 - accuracy: 0.8558
759/1875 [===========>..................] - ETA: 7s - loss: 0.4193 - accuracy: 0.8556
767/1875 [===========>..................] - ETA: 7s - loss: 0.4181 - accuracy: 0.8561
775/1875 [===========>..................] - ETA: 7s - loss: 0.4176 - accuracy: 0.8559
783/1875 [===========>..................] - ETA: 7s - loss: 0.4179 - accuracy: 0.8560
790/1875 [===========>..................] - ETA: 7s - loss: 0.4171 - accuracy: 0.8562
798/1875 [===========>..................] - ETA: 7s - loss: 0.4163 - accuracy: 0.8566
806/1875 [===========>..................] - ETA: 6s - loss: 0.4161 - accuracy: 0.8567
814/1875 [============>.................] - ETA: 6s - loss: 0.4163 - accuracy: 0.8566
822/1875 [============>.................] - ETA: 6s - loss: 0.4163 - accuracy: 0.8566
830/1875 [============>.................] - ETA: 6s - loss: 0.4160 - accuracy: 0.8566
838/1875 [============>.................] - ETA: 6s - loss: 0.4167 - accuracy: 0.8565
846/1875 [============>.................] - ETA: 6s - loss: 0.4165 - accuracy: 0.8566
854/1875 [============>.................] - ETA: 6s - loss: 0.4159 - accuracy: 0.8568
862/1875 [============>.................] - ETA: 6s - loss: 0.4168 - accuracy: 0.8565
870/1875 [============>.................] - ETA: 6s - loss: 0.4173 - accuracy: 0.8563
878/1875 [=============>................] - ETA: 6s - loss: 0.4176 - accuracy: 0.8562
886/1875 [=============>................] - ETA: 6s - loss: 0.4172 - accuracy: 0.8565
894/1875 [=============>................] - ETA: 6s - loss: 0.4173 - accuracy: 0.8564
902/1875 [=============>................] - ETA: 6s - loss: 0.4169 - accuracy: 0.8563
910/1875 [=============>................] - ETA: 6s - loss: 0.4176 - accuracy: 0.8561
918/1875 [=============>................] - ETA: 6s - loss: 0.4183 - accuracy: 0.8562
926/1875 [=============>................] - ETA: 6s - loss: 0.4179 - accuracy: 0.8565
933/1875 [=============>................] - ETA: 6s - loss: 0.4176 - accuracy: 0.8564
941/1875 [==============>...............] - ETA: 6s - loss: 0.4178 - accuracy: 0.8563
949/1875 [==============>...............] - ETA: 6s - loss: 0.4180 - accuracy: 0.8560
957/1875 [==============>...............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8556
965/1875 [==============>...............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8559
973/1875 [==============>...............] - ETA: 5s - loss: 0.4179 - accuracy: 0.8559
981/1875 [==============>...............] - ETA: 5s - loss: 0.4173 - accuracy: 0.8560
989/1875 [==============>...............] - ETA: 5s - loss: 0.4186 - accuracy: 0.8559
997/1875 [==============>...............] - ETA: 5s - loss: 0.4190 - accuracy: 0.8558
1005/1875 [===============>..............] - ETA: 5s - loss: 0.4186 - accuracy: 0.8560
1013/1875 [===============>..............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8563
1021/1875 [===============>..............] - ETA: 5s - loss: 0.4180 - accuracy: 0.8562
1029/1875 [===============>..............] - ETA: 5s - loss: 0.4181 - accuracy: 0.8561
1037/1875 [===============>..............] - ETA: 5s - loss: 0.4174 - accuracy: 0.8564
1045/1875 [===============>..............] - ETA: 5s - loss: 0.4171 - accuracy: 0.8567
1053/1875 [===============>..............] - ETA: 5s - loss: 0.4167 - accuracy: 0.8568
1061/1875 [===============>..............] - ETA: 5s - loss: 0.4175 - accuracy: 0.8566
1069/1875 [================>.............] - ETA: 5s - loss: 0.4176 - accuracy: 0.8566
1077/1875 [================>.............] - ETA: 5s - loss: 0.4181 - accuracy: 0.8563
1085/1875 [================>.............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8562
1093/1875 [================>.............] - ETA: 5s - loss: 0.4184 - accuracy: 0.8560
1101/1875 [================>.............] - ETA: 5s - loss: 0.4184 - accuracy: 0.8560
1109/1875 [================>.............] - ETA: 4s - loss: 0.4179 - accuracy: 0.8563
1117/1875 [================>.............] - ETA: 4s - loss: 0.4180 - accuracy: 0.8561
1125/1875 [=================>............] - ETA: 4s - loss: 0.4183 - accuracy: 0.8561
1133/1875 [=================>............] - ETA: 4s - loss: 0.4181 - accuracy: 0.8562
1141/1875 [=================>............] - ETA: 4s - loss: 0.4177 - accuracy: 0.8563
1149/1875 [=================>............] - ETA: 4s - loss: 0.4173 - accuracy: 0.8563
1157/1875 [=================>............] - ETA: 4s - loss: 0.4168 - accuracy: 0.8566
1160/1875 [=================>............] - ETA: 34s - loss: 0.4168 - accuracy: 0.8565
1166/1875 [=================>............] - ETA: 33s - loss: 0.4170 - accuracy: 0.8566
1173/1875 [=================>............] - ETA: 33s - loss: 0.4167 - accuracy: 0.8568
1181/1875 [=================>............] - ETA: 32s - loss: 0.4164 - accuracy: 0.8569
1189/1875 [==================>...........] - ETA: 31s - loss: 0.4166 - accuracy: 0.8570
1197/1875 [==================>...........] - ETA: 31s - loss: 0.4163 - accuracy: 0.8570
1205/1875 [==================>...........] - ETA: 30s - loss: 0.4163 - accuracy: 0.8570
1212/1875 [==================>...........] - ETA: 30s - loss: 0.4170 - accuracy: 0.8567
1219/1875 [==================>...........] - ETA: 29s - loss: 0.4175 - accuracy: 0.8566
1227/1875 [==================>...........] - ETA: 29s - loss: 0.4179 - accuracy: 0.8566
1235/1875 [==================>...........] - ETA: 28s - loss: 0.4170 - accuracy: 0.8566
1243/1875 [==================>...........] - ETA: 28s - loss: 0.4177 - accuracy: 0.8563
1251/1875 [===================>..........] - ETA: 27s - loss: 0.4182 - accuracy: 0.8563
1259/1875 [===================>..........] - ETA: 27s - loss: 0.4183 - accuracy: 0.8560
1267/1875 [===================>..........] - ETA: 26s - loss: 0.4184 - accuracy: 0.8558
1275/1875 [===================>..........] - ETA: 26s - loss: 0.4188 - accuracy: 0.8558
1283/1875 [===================>..........] - ETA: 25s - loss: 0.4182 - accuracy: 0.8557
1291/1875 [===================>..........] - ETA: 25s - loss: 0.4191 - accuracy: 0.8557
1299/1875 [===================>..........] - ETA: 24s - loss: 0.4186 - accuracy: 0.8556
1307/1875 [===================>..........] - ETA: 24s - loss: 0.4186 - accuracy: 0.8558
1315/1875 [====================>.........] - ETA: 23s - loss: 0.4197 - accuracy: 0.8555
1322/1875 [====================>.........] - ETA: 23s - loss: 0.4197 - accuracy: 0.8555
1330/1875 [====================>.........] - ETA: 23s - loss: 0.4194 - accuracy: 0.8555
1338/1875 [====================>.........] - ETA: 22s - loss: 0.4189 - accuracy: 0.8555
1346/1875 [====================>.........] - ETA: 22s - loss: 0.4197 - accuracy: 0.8551
1353/1875 [====================>.........] - ETA: 21s - loss: 0.4197 - accuracy: 0.8552
1361/1875 [====================>.........] - ETA: 21s - loss: 0.4198 - accuracy: 0.8552
1369/1875 [====================>.........] - ETA: 20s - loss: 0.4191 - accuracy: 0.8554
1377/1875 [=====================>........] - ETA: 20s - loss: 0.4199 - accuracy: 0.8552
1385/1875 [=====================>........] - ETA: 20s - loss: 0.4198 - accuracy: 0.8554
1393/1875 [=====================>........] - ETA: 19s - loss: 0.4195 - accuracy: 0.8555
1400/1875 [=====================>........] - ETA: 19s - loss: 0.4193 - accuracy: 0.8555
1408/1875 [=====================>........] - ETA: 18s - loss: 0.4195 - accuracy: 0.8554
1416/1875 [=====================>........] - ETA: 18s - loss: 0.4192 - accuracy: 0.8556
1424/1875 [=====================>........] - ETA: 18s - loss: 0.4201 - accuracy: 0.8555
1431/1875 [=====================>........] - ETA: 17s - loss: 0.4199 - accuracy: 0.8555
1439/1875 [======================>.......] - ETA: 17s - loss: 0.4197 - accuracy: 0.8555
1447/1875 [======================>.......] - ETA: 16s - loss: 0.4206 - accuracy: 0.8556
1455/1875 [======================>.......] - ETA: 16s - loss: 0.4204 - accuracy: 0.8556
1463/1875 [======================>.......] - ETA: 16s - loss: 0.4200 - accuracy: 0.8557
1471/1875 [======================>.......] - ETA: 15s - loss: 0.4206 - accuracy: 0.8556
1479/1875 [======================>.......] - ETA: 15s - loss: 0.4204 - accuracy: 0.8556
1487/1875 [======================>.......] - ETA: 14s - loss: 0.4204 - accuracy: 0.8556
1495/1875 [======================>.......] - ETA: 14s - loss: 0.4197 - accuracy: 0.8557
1503/1875 [=======================>......] - ETA: 14s - loss: 0.4204 - accuracy: 0.8555
1511/1875 [=======================>......] - ETA: 13s - loss: 0.4209 - accuracy: 0.8553
1519/1875 [=======================>......] - ETA: 13s - loss: 0.4212 - accuracy: 0.8552
1527/1875 [=======================>......] - ETA: 13s - loss: 0.4208 - accuracy: 0.8552
1535/1875 [=======================>......] - ETA: 12s - loss: 0.4205 - accuracy: 0.8553
1543/1875 [=======================>......] - ETA: 12s - loss: 0.4208 - accuracy: 0.8552
1551/1875 [=======================>......] - ETA: 12s - loss: 0.4205 - accuracy: 0.8552
1559/1875 [=======================>......] - ETA: 11s - loss: 0.4209 - accuracy: 0.8552
1567/1875 [========================>.....] - ETA: 11s - loss: 0.4212 - accuracy: 0.8553
1575/1875 [========================>.....] - ETA: 11s - loss: 0.4208 - accuracy: 0.8553
1583/1875 [========================>.....] - ETA: 10s - loss: 0.4202 - accuracy: 0.8556
1591/1875 [========================>.....] - ETA: 10s - loss: 0.4203 - accuracy: 0.8554
1599/1875 [========================>.....] - ETA: 10s - loss: 0.4198 - accuracy: 0.8555
1607/1875 [========================>.....] - ETA: 9s - loss: 0.4204 - accuracy: 0.8555
1615/1875 [========================>.....] - ETA: 9s - loss: 0.4208 - accuracy: 0.8554
1623/1875 [========================>.....] - ETA: 9s - loss: 0.4207 - accuracy: 0.8555
1631/1875 [=========================>....] - ETA: 8s - loss: 0.4203 - accuracy: 0.8555
1639/1875 [=========================>....] - ETA: 8s - loss: 0.4201 - accuracy: 0.8555
1647/1875 [=========================>....] - ETA: 8s - loss: 0.4204 - accuracy: 0.8554
1654/1875 [=========================>....] - ETA: 7s - loss: 0.4203 - accuracy: 0.8555
1662/1875 [=========================>....] - ETA: 7s - loss: 0.4199 - accuracy: 0.8557
1670/1875 [=========================>....] - ETA: 7s - loss: 0.4197 - accuracy: 0.8557
1678/1875 [=========================>....] - ETA: 6s - loss: 0.4197 - accuracy: 0.8557
1686/1875 [=========================>....] - ETA: 6s - loss: 0.4196 - accuracy: 0.8557
1694/1875 [==========================>...] - ETA: 6s - loss: 0.4196 - accuracy: 0.8556
1702/1875 [==========================>...] - ETA: 5s - loss: 0.4195 - accuracy: 0.8556
1710/1875 [==========================>...] - ETA: 5s - loss: 0.4191 - accuracy: 0.8558
1718/1875 [==========================>...] - ETA: 5s - loss: 0.4190 - accuracy: 0.8558
1726/1875 [==========================>...] - ETA: 5s - loss: 0.4186 - accuracy: 0.8559
1734/1875 [==========================>...] - ETA: 4s - loss: 0.4186 - accuracy: 0.8558
1743/1875 [==========================>...] - ETA: 4s - loss: 0.4185 - accuracy: 0.8557
1751/1875 [===========================>..] - ETA: 4s - loss: 0.4184 - accuracy: 0.8557
1759/1875 [===========================>..] - ETA: 3s - loss: 0.4180 - accuracy: 0.8557
1767/1875 [===========================>..] - ETA: 3s - loss: 0.4180 - accuracy: 0.8556
1775/1875 [===========================>..] - ETA: 3s - loss: 0.4183 - accuracy: 0.8557
1783/1875 [===========================>..] - ETA: 3s - loss: 0.4183 - accuracy: 0.8557
1791/1875 [===========================>..] - ETA: 2s - loss: 0.4184 - accuracy: 0.8557
1799/1875 [===========================>..] - ETA: 2s - loss: 0.4185 - accuracy: 0.8556
1807/1875 [===========================>..] - ETA: 2s - loss: 0.4188 - accuracy: 0.8556
1815/1875 [============================>.] - ETA: 1s - loss: 0.4191 - accuracy: 0.8555
1823/1875 [============================>.] - ETA: 1s - loss: 0.4192 - accuracy: 0.8555
1831/1875 [============================>.] - ETA: 1s - loss: 0.4192 - accuracy: 0.8554
1839/1875 [============================>.] - ETA: 1s - loss: 0.4185 - accuracy: 0.8556
1847/1875 [============================>.] - ETA: 0s - loss: 0.4177 - accuracy: 0.8558
1855/1875 [============================>.] - ETA: 0s - loss: 0.4179 - accuracy: 0.8557
1863/1875 [============================>.] - ETA: 0s - loss: 0.4184 - accuracy: 0.8556
1871/1875 [============================>.] - ETA: 0s - loss: 0.4184 - accuracy: 0.8556
1875/1875 [==============================] - 60s 32ms/step - loss: 0.4183 - accuracy: 0.8556
Epoch 3/3
1/1875 [..............................] - ETA: 13s - loss: 0.4762 - accuracy: 0.8438
9/1875 [..............................] - ETA: 12s - loss: 0.3872 - accuracy: 0.8819
17/1875 [..............................] - ETA: 12s - loss: 0.4069 - accuracy: 0.8695
25/1875 [..............................] - ETA: 12s - loss: 0.3952 - accuracy: 0.8625
33/1875 [..............................] - ETA: 12s - loss: 0.3946 - accuracy: 0.8627
41/1875 [..............................] - ETA: 12s - loss: 0.4041 - accuracy: 0.8575
49/1875 [..............................] - ETA: 12s - loss: 0.3877 - accuracy: 0.8648
57/1875 [..............................] - ETA: 12s - loss: 0.3806 - accuracy: 0.8668
65/1875 [>.............................] - ETA: 12s - loss: 0.3942 - accuracy: 0.8654
73/1875 [>.............................] - ETA: 11s - loss: 0.3962 - accuracy: 0.8660
81/1875 [>.............................] - ETA: 11s - loss: 0.3860 - accuracy: 0.8692
89/1875 [>.............................] - ETA: 11s - loss: 0.3862 - accuracy: 0.8697
97/1875 [>.............................] - ETA: 11s - loss: 0.3872 - accuracy: 0.8679
105/1875 [>.............................] - ETA: 11s - loss: 0.3864 - accuracy: 0.8682
113/1875 [>.............................] - ETA: 11s - loss: 0.3892 - accuracy: 0.8673
121/1875 [>.............................] - ETA: 11s - loss: 0.3910 - accuracy: 0.8665
129/1875 [=>............................] - ETA: 11s - loss: 0.3909 - accuracy: 0.8668
137/1875 [=>............................] - ETA: 11s - loss: 0.3953 - accuracy: 0.8638
145/1875 [=>............................] - ETA: 11s - loss: 0.3908 - accuracy: 0.8653
153/1875 [=>............................] - ETA: 11s - loss: 0.3887 - accuracy: 0.8664
161/1875 [=>............................] - ETA: 11s - loss: 0.3859 - accuracy: 0.8667
169/1875 [=>............................] - ETA: 11s - loss: 0.3817 - accuracy: 0.8674
177/1875 [=>............................] - ETA: 11s - loss: 0.3833 - accuracy: 0.8665
185/1875 [=>............................] - ETA: 11s - loss: 0.3811 - accuracy: 0.8671
193/1875 [==>...........................] - ETA: 11s - loss: 0.3773 - accuracy: 0.8687
201/1875 [==>...........................] - ETA: 10s - loss: 0.3760 - accuracy: 0.8696
209/1875 [==>...........................] - ETA: 10s - loss: 0.3798 - accuracy: 0.8690
217/1875 [==>...........................] - ETA: 10s - loss: 0.3843 - accuracy: 0.8681
225/1875 [==>...........................] - ETA: 10s - loss: 0.3807 - accuracy: 0.8694
233/1875 [==>...........................] - ETA: 10s - loss: 0.3836 - accuracy: 0.8695
241/1875 [==>...........................] - ETA: 10s - loss: 0.3863 - accuracy: 0.8692
249/1875 [==>...........................] - ETA: 10s - loss: 0.3859 - accuracy: 0.8691
257/1875 [===>..........................] - ETA: 10s - loss: 0.3830 - accuracy: 0.8695
265/1875 [===>..........................] - ETA: 10s - loss: 0.3873 - accuracy: 0.8704
273/1875 [===>..........................] - ETA: 10s - loss: 0.3887 - accuracy: 0.8703
281/1875 [===>..........................] - ETA: 10s - loss: 0.3878 - accuracy: 0.8700
289/1875 [===>..........................] - ETA: 10s - loss: 0.3866 - accuracy: 0.8700
297/1875 [===>..........................] - ETA: 10s - loss: 0.3900 - accuracy: 0.8695
305/1875 [===>..........................] - ETA: 10s - loss: 0.3886 - accuracy: 0.8701
313/1875 [====>.........................] - ETA: 10s - loss: 0.3881 - accuracy: 0.8699
321/1875 [====>.........................] - ETA: 10s - loss: 0.3893 - accuracy: 0.8698
329/1875 [====>.........................] - ETA: 10s - loss: 0.3875 - accuracy: 0.8705
337/1875 [====>.........................] - ETA: 10s - loss: 0.3867 - accuracy: 0.8705
345/1875 [====>.........................] - ETA: 9s - loss: 0.3862 - accuracy: 0.8699
353/1875 [====>.........................] - ETA: 9s - loss: 0.3871 - accuracy: 0.8692
362/1875 [====>.........................] - ETA: 9s - loss: 0.3867 - accuracy: 0.8688
370/1875 [====>.........................] - ETA: 9s - loss: 0.3863 - accuracy: 0.8692
378/1875 [=====>........................] - ETA: 9s - loss: 0.3881 - accuracy: 0.8690
386/1875 [=====>........................] - ETA: 9s - loss: 0.3917 - accuracy: 0.8678
394/1875 [=====>........................] - ETA: 9s - loss: 0.3908 - accuracy: 0.8675
402/1875 [=====>........................] - ETA: 9s - loss: 0.3938 - accuracy: 0.8673
410/1875 [=====>........................] - ETA: 9s - loss: 0.3929 - accuracy: 0.8673
418/1875 [=====>........................] - ETA: 9s - loss: 0.3925 - accuracy: 0.8677
426/1875 [=====>........................] - ETA: 9s - loss: 0.3928 - accuracy: 0.8681
434/1875 [=====>........................] - ETA: 9s - loss: 0.3915 - accuracy: 0.8686
442/1875 [======>.......................] - ETA: 9s - loss: 0.3919 - accuracy: 0.8684
450/1875 [======>.......................] - ETA: 9s - loss: 0.3931 - accuracy: 0.8681
458/1875 [======>.......................] - ETA: 9s - loss: 0.3929 - accuracy: 0.8682
466/1875 [======>.......................] - ETA: 9s - loss: 0.3933 - accuracy: 0.8678
474/1875 [======>.......................] - ETA: 9s - loss: 0.3925 - accuracy: 0.8681
482/1875 [======>.......................] - ETA: 9s - loss: 0.3929 - accuracy: 0.8681
490/1875 [======>.......................] - ETA: 9s - loss: 0.3921 - accuracy: 0.8682
498/1875 [======>.......................] - ETA: 8s - loss: 0.3915 - accuracy: 0.8685
506/1875 [=======>......................] - ETA: 8s - loss: 0.3907 - accuracy: 0.8686
514/1875 [=======>......................] - ETA: 8s - loss: 0.3911 - accuracy: 0.8689
522/1875 [=======>......................] - ETA: 8s - loss: 0.3922 - accuracy: 0.8681
530/1875 [=======>......................] - ETA: 8s - loss: 0.3933 - accuracy: 0.8676
538/1875 [=======>......................] - ETA: 8s - loss: 0.3924 - accuracy: 0.8682
546/1875 [=======>......................] - ETA: 8s - loss: 0.3926 - accuracy: 0.8682
554/1875 [=======>......................] - ETA: 8s - loss: 0.3938 - accuracy: 0.8679
562/1875 [=======>......................] - ETA: 8s - loss: 0.3933 - accuracy: 0.8677
569/1875 [========>.....................] - ETA: 8s - loss: 0.3951 - accuracy: 0.8676
577/1875 [========>.....................] - ETA: 8s - loss: 0.3938 - accuracy: 0.8679
585/1875 [========>.....................] - ETA: 8s - loss: 0.3943 - accuracy: 0.8678
593/1875 [========>.....................] - ETA: 8s - loss: 0.3943 - accuracy: 0.8679
601/1875 [========>.....................] - ETA: 8s - loss: 0.3951 - accuracy: 0.8678
609/1875 [========>.....................] - ETA: 8s - loss: 0.3958 - accuracy: 0.8676
617/1875 [========>.....................] - ETA: 8s - loss: 0.3946 - accuracy: 0.8680
625/1875 [=========>....................] - ETA: 8s - loss: 0.3947 - accuracy: 0.8679
633/1875 [=========>....................] - ETA: 8s - loss: 0.3933 - accuracy: 0.8681
641/1875 [=========>....................] - ETA: 8s - loss: 0.3931 - accuracy: 0.8678
649/1875 [=========>....................] - ETA: 8s - loss: 0.3937 - accuracy: 0.8672
657/1875 [=========>....................] - ETA: 7s - loss: 0.3933 - accuracy: 0.8669
665/1875 [=========>....................] - ETA: 7s - loss: 0.3927 - accuracy: 0.8669
673/1875 [=========>....................] - ETA: 7s - loss: 0.3933 - accuracy: 0.8668
681/1875 [=========>....................] - ETA: 7s - loss: 0.3941 - accuracy: 0.8666
689/1875 [==========>...................] - ETA: 7s - loss: 0.3951 - accuracy: 0.8665
697/1875 [==========>...................] - ETA: 7s - loss: 0.3950 - accuracy: 0.8662
705/1875 [==========>...................] - ETA: 7s - loss: 0.3949 - accuracy: 0.8664
713/1875 [==========>...................] - ETA: 7s - loss: 0.3947 - accuracy: 0.8661
721/1875 [==========>...................] - ETA: 7s - loss: 0.3953 - accuracy: 0.8660
729/1875 [==========>...................] - ETA: 7s - loss: 0.3945 - accuracy: 0.8664
737/1875 [==========>...................] - ETA: 7s - loss: 0.3944 - accuracy: 0.8665
745/1875 [==========>...................] - ETA: 7s - loss: 0.3948 - accuracy: 0.8662
753/1875 [===========>..................] - ETA: 7s - loss: 0.3943 - accuracy: 0.8665
761/1875 [===========>..................] - ETA: 7s - loss: 0.3953 - accuracy: 0.8660
769/1875 [===========>..................] - ETA: 7s - loss: 0.3969 - accuracy: 0.8656
777/1875 [===========>..................] - ETA: 7s - loss: 0.3974 - accuracy: 0.8655
785/1875 [===========>..................] - ETA: 7s - loss: 0.3983 - accuracy: 0.8652
794/1875 [===========>..................] - ETA: 7s - loss: 0.3986 - accuracy: 0.8652
802/1875 [===========>..................] - ETA: 6s - loss: 0.3979 - accuracy: 0.8654
810/1875 [===========>..................] - ETA: 6s - loss: 0.3978 - accuracy: 0.8654
818/1875 [============>.................] - ETA: 6s - loss: 0.3986 - accuracy: 0.8649
826/1875 [============>.................] - ETA: 6s - loss: 0.3989 - accuracy: 0.8647
834/1875 [============>.................] - ETA: 6s - loss: 0.3991 - accuracy: 0.8647
842/1875 [============>.................] - ETA: 6s - loss: 0.3986 - accuracy: 0.8646
850/1875 [============>.................] - ETA: 6s - loss: 0.3986 - accuracy: 0.8644
858/1875 [============>.................] - ETA: 6s - loss: 0.3991 - accuracy: 0.8645
866/1875 [============>.................] - ETA: 6s - loss: 0.3989 - accuracy: 0.8645
874/1875 [============>.................] - ETA: 6s - loss: 0.3991 - accuracy: 0.8644
882/1875 [=============>................] - ETA: 6s - loss: 0.3992 - accuracy: 0.8641
890/1875 [=============>................] - ETA: 6s - loss: 0.3987 - accuracy: 0.8642
898/1875 [=============>................] - ETA: 6s - loss: 0.3988 - accuracy: 0.8642
906/1875 [=============>................] - ETA: 6s - loss: 0.3977 - accuracy: 0.8644
914/1875 [=============>................] - ETA: 6s - loss: 0.3977 - accuracy: 0.8645
922/1875 [=============>................] - ETA: 6s - loss: 0.3976 - accuracy: 0.8645
930/1875 [=============>................] - ETA: 6s - loss: 0.3979 - accuracy: 0.8644
938/1875 [==============>...............] - ETA: 6s - loss: 0.3976 - accuracy: 0.8643
946/1875 [==============>...............] - ETA: 6s - loss: 0.3987 - accuracy: 0.8641
954/1875 [==============>...............] - ETA: 5s - loss: 0.3989 - accuracy: 0.8642
962/1875 [==============>...............] - ETA: 5s - loss: 0.3992 - accuracy: 0.8642
970/1875 [==============>...............] - ETA: 5s - loss: 0.3994 - accuracy: 0.8643
978/1875 [==============>...............] - ETA: 5s - loss: 0.3991 - accuracy: 0.8644
986/1875 [==============>...............] - ETA: 5s - loss: 0.3988 - accuracy: 0.8644
989/1875 [==============>...............] - ETA: 53s - loss: 0.3987 - accuracy: 0.8643
993/1875 [==============>...............] - ETA: 52s - loss: 0.3983 - accuracy: 0.8643
1000/1875 [===============>..............] - ETA: 52s - loss: 0.3984 - accuracy: 0.8643
1008/1875 [===============>..............] - ETA: 51s - loss: 0.3977 - accuracy: 0.8645
1016/1875 [===============>..............] - ETA: 50s - loss: 0.3979 - accuracy: 0.8642
1024/1875 [===============>..............] - ETA: 49s - loss: 0.3979 - accuracy: 0.8642
1032/1875 [===============>..............] - ETA: 48s - loss: 0.3977 - accuracy: 0.8642
1039/1875 [===============>..............] - ETA: 48s - loss: 0.3977 - accuracy: 0.8641
1047/1875 [===============>..............] - ETA: 47s - loss: 0.3974 - accuracy: 0.8642
1055/1875 [===============>..............] - ETA: 46s - loss: 0.3977 - accuracy: 0.8640
1063/1875 [================>.............] - ETA: 45s - loss: 0.3983 - accuracy: 0.8638
1071/1875 [================>.............] - ETA: 45s - loss: 0.3980 - accuracy: 0.8637
1079/1875 [================>.............] - ETA: 44s - loss: 0.3977 - accuracy: 0.8638
1087/1875 [================>.............] - ETA: 43s - loss: 0.3981 - accuracy: 0.8638
1095/1875 [================>.............] - ETA: 42s - loss: 0.3977 - accuracy: 0.8638
1103/1875 [================>.............] - ETA: 42s - loss: 0.3982 - accuracy: 0.8635
1111/1875 [================>.............] - ETA: 41s - loss: 0.3997 - accuracy: 0.8634
1119/1875 [================>.............] - ETA: 40s - loss: 0.3994 - accuracy: 0.8635
1127/1875 [=================>............] - ETA: 40s - loss: 0.3994 - accuracy: 0.8635
1135/1875 [=================>............] - ETA: 39s - loss: 0.3995 - accuracy: 0.8634
1143/1875 [=================>............] - ETA: 38s - loss: 0.3991 - accuracy: 0.8636
1151/1875 [=================>............] - ETA: 38s - loss: 0.3994 - accuracy: 0.8635
1159/1875 [=================>............] - ETA: 37s - loss: 0.3986 - accuracy: 0.8636
1167/1875 [=================>............] - ETA: 36s - loss: 0.3984 - accuracy: 0.8636
1175/1875 [=================>............] - ETA: 36s - loss: 0.3980 - accuracy: 0.8636
1183/1875 [=================>............] - ETA: 35s - loss: 0.3975 - accuracy: 0.8637
1191/1875 [==================>...........] - ETA: 35s - loss: 0.3976 - accuracy: 0.8637
1199/1875 [==================>...........] - ETA: 34s - loss: 0.3980 - accuracy: 0.8637
1207/1875 [==================>...........] - ETA: 33s - loss: 0.3979 - accuracy: 0.8636
1215/1875 [==================>...........] - ETA: 33s - loss: 0.3978 - accuracy: 0.8636
1222/1875 [==================>...........] - ETA: 32s - loss: 0.3978 - accuracy: 0.8637
1230/1875 [==================>...........] - ETA: 32s - loss: 0.3978 - accuracy: 0.8638
1237/1875 [==================>...........] - ETA: 31s - loss: 0.3976 - accuracy: 0.8639
1245/1875 [==================>...........] - ETA: 31s - loss: 0.3971 - accuracy: 0.8640
1253/1875 [===================>..........] - ETA: 30s - loss: 0.3968 - accuracy: 0.8643
1261/1875 [===================>..........] - ETA: 29s - loss: 0.3968 - accuracy: 0.8643
1269/1875 [===================>..........] - ETA: 29s - loss: 0.3970 - accuracy: 0.8645
1276/1875 [===================>..........] - ETA: 28s - loss: 0.3978 - accuracy: 0.8643
1284/1875 [===================>..........] - ETA: 28s - loss: 0.3976 - accuracy: 0.8643
1292/1875 [===================>..........] - ETA: 27s - loss: 0.3975 - accuracy: 0.8643
1300/1875 [===================>..........] - ETA: 27s - loss: 0.3974 - accuracy: 0.8645
1308/1875 [===================>..........] - ETA: 26s - loss: 0.3973 - accuracy: 0.8646
1316/1875 [====================>.........] - ETA: 26s - loss: 0.3971 - accuracy: 0.8647
1324/1875 [====================>.........] - ETA: 25s - loss: 0.3969 - accuracy: 0.8647
1332/1875 [====================>.........] - ETA: 25s - loss: 0.3974 - accuracy: 0.8647
1340/1875 [====================>.........] - ETA: 24s - loss: 0.3976 - accuracy: 0.8646
1348/1875 [====================>.........] - ETA: 24s - loss: 0.3975 - accuracy: 0.8645
1356/1875 [====================>.........] - ETA: 23s - loss: 0.3982 - accuracy: 0.8645
1363/1875 [====================>.........] - ETA: 23s - loss: 0.3989 - accuracy: 0.8642
1371/1875 [====================>.........] - ETA: 22s - loss: 0.3987 - accuracy: 0.8643
1379/1875 [=====================>........] - ETA: 22s - loss: 0.3984 - accuracy: 0.8643
1387/1875 [=====================>........] - ETA: 21s - loss: 0.3982 - accuracy: 0.8644
1395/1875 [=====================>........] - ETA: 21s - loss: 0.3984 - accuracy: 0.8643
1403/1875 [=====================>........] - ETA: 21s - loss: 0.3984 - accuracy: 0.8644
1411/1875 [=====================>........] - ETA: 20s - loss: 0.3989 - accuracy: 0.8641
1419/1875 [=====================>........] - ETA: 20s - loss: 0.3987 - accuracy: 0.8640
1427/1875 [=====================>........] - ETA: 19s - loss: 0.3987 - accuracy: 0.8641
1435/1875 [=====================>........] - ETA: 19s - loss: 0.3994 - accuracy: 0.8639
1443/1875 [======================>.......] - ETA: 18s - loss: 0.3996 - accuracy: 0.8639
1451/1875 [======================>.......] - ETA: 18s - loss: 0.4005 - accuracy: 0.8638
1459/1875 [======================>.......] - ETA: 17s - loss: 0.4006 - accuracy: 0.8638
1467/1875 [======================>.......] - ETA: 17s - loss: 0.4013 - accuracy: 0.8635
1475/1875 [======================>.......] - ETA: 17s - loss: 0.4007 - accuracy: 0.8637
1483/1875 [======================>.......] - ETA: 16s - loss: 0.4005 - accuracy: 0.8638
1491/1875 [======================>.......] - ETA: 16s - loss: 0.4001 - accuracy: 0.8641
1499/1875 [======================>.......] - ETA: 15s - loss: 0.4001 - accuracy: 0.8640
1507/1875 [=======================>......] - ETA: 15s - loss: 0.4003 - accuracy: 0.8641
1515/1875 [=======================>......] - ETA: 15s - loss: 0.4002 - accuracy: 0.8642
1523/1875 [=======================>......] - ETA: 14s - loss: 0.4006 - accuracy: 0.8639
1531/1875 [=======================>......] - ETA: 14s - loss: 0.4010 - accuracy: 0.8638
1539/1875 [=======================>......] - ETA: 13s - loss: 0.4012 - accuracy: 0.8637
1547/1875 [=======================>......] - ETA: 13s - loss: 0.4015 - accuracy: 0.8637
1555/1875 [=======================>......] - ETA: 13s - loss: 0.4019 - accuracy: 0.8635
1563/1875 [========================>.....] - ETA: 12s - loss: 0.4015 - accuracy: 0.8636
1571/1875 [========================>.....] - ETA: 12s - loss: 0.4012 - accuracy: 0.8637
1579/1875 [========================>.....] - ETA: 11s - loss: 0.4007 - accuracy: 0.8639
1587/1875 [========================>.....] - ETA: 11s - loss: 0.4006 - accuracy: 0.8639
1595/1875 [========================>.....] - ETA: 11s - loss: 0.4001 - accuracy: 0.8639
1603/1875 [========================>.....] - ETA: 10s - loss: 0.3993 - accuracy: 0.8642
1612/1875 [========================>.....] - ETA: 10s - loss: 0.3991 - accuracy: 0.8642
1620/1875 [========================>.....] - ETA: 10s - loss: 0.3997 - accuracy: 0.8640
1628/1875 [=========================>....] - ETA: 9s - loss: 0.4001 - accuracy: 0.8641
1636/1875 [=========================>....] - ETA: 9s - loss: 0.4000 - accuracy: 0.8641
1644/1875 [=========================>....] - ETA: 8s - loss: 0.3998 - accuracy: 0.8640
1652/1875 [=========================>....] - ETA: 8s - loss: 0.3993 - accuracy: 0.8641
1660/1875 [=========================>....] - ETA: 8s - loss: 0.3991 - accuracy: 0.8643
1668/1875 [=========================>....] - ETA: 7s - loss: 0.3988 - accuracy: 0.8642
1676/1875 [=========================>....] - ETA: 7s - loss: 0.3985 - accuracy: 0.8643
1684/1875 [=========================>....] - ETA: 7s - loss: 0.3987 - accuracy: 0.8641
1692/1875 [==========================>...] - ETA: 6s - loss: 0.3988 - accuracy: 0.8642
1701/1875 [==========================>...] - ETA: 6s - loss: 0.3988 - accuracy: 0.8641
1709/1875 [==========================>...] - ETA: 6s - loss: 0.3987 - accuracy: 0.8641
1717/1875 [==========================>...] - ETA: 5s - loss: 0.3986 - accuracy: 0.8642
1725/1875 [==========================>...] - ETA: 5s - loss: 0.3982 - accuracy: 0.8643
1733/1875 [==========================>...] - ETA: 5s - loss: 0.3985 - accuracy: 0.8643
1741/1875 [==========================>...] - ETA: 4s - loss: 0.3987 - accuracy: 0.8642
1749/1875 [==========================>...] - ETA: 4s - loss: 0.3990 - accuracy: 0.8641
1757/1875 [===========================>..] - ETA: 4s - loss: 0.3988 - accuracy: 0.8642
1765/1875 [===========================>..] - ETA: 4s - loss: 0.3985 - accuracy: 0.8644
1773/1875 [===========================>..] - ETA: 3s - loss: 0.3984 - accuracy: 0.8644
1780/1875 [===========================>..] - ETA: 3s - loss: 0.3980 - accuracy: 0.8645
1788/1875 [===========================>..] - ETA: 3s - loss: 0.3979 - accuracy: 0.8646
1796/1875 [===========================>..] - ETA: 2s - loss: 0.3979 - accuracy: 0.8645
1804/1875 [===========================>..] - ETA: 2s - loss: 0.3975 - accuracy: 0.8644
1812/1875 [===========================>..] - ETA: 2s - loss: 0.3977 - accuracy: 0.8644
1820/1875 [============================>.] - ETA: 1s - loss: 0.3977 - accuracy: 0.8645
1828/1875 [============================>.] - ETA: 1s - loss: 0.3975 - accuracy: 0.8646
1836/1875 [============================>.] - ETA: 1s - loss: 0.3978 - accuracy: 0.8645
1844/1875 [============================>.] - ETA: 1s - loss: 0.3980 - accuracy: 0.8644
1850/1875 [============================>.] - ETA: 0s - loss: 0.3978 - accuracy: 0.8644
1858/1875 [============================>.] - ETA: 0s - loss: 0.3976 - accuracy: 0.8645
1865/1875 [============================>.] - ETA: 0s - loss: 0.3975 - accuracy: 0.8645
1873/1875 [============================>.] - ETA: 0s - loss: 0.3980 - accuracy: 0.8644
1875/1875 [==============================] - 66s 35ms/step - loss: 0.3980 - accuracy: 0.8643
Predictions and evaluations¶
We can now call predict to generate predictions, and evaluate the trained model on the entire test set.
# Generate class-probability predictions for every test example.
# NOTE(review): the return value is discarded here; later cells evaluate on
# Xf_test/yf_test — confirm X_test/y_test are the intended arrays.
network.predict(X_test)
# Evaluate the trained model on the full test set (returns loss and accuracy).
test_loss, test_acc = network.evaluate(X_test, y_test)
# Show one test image with its true label, then print the model's
# predicted probability for each of the 10 classes.
np.set_printoptions(precision=7)
fig, axes = plt.subplots(1, 1, figsize=(2, 2))
sample_id = 4  # index of the test example to inspect
# Reshape the flattened 784-vector back to a 28x28 image for display
axes.imshow(Xf_test[sample_id].reshape(28, 28), cmap=plt.cm.gray_r)
axes.set_xlabel("True label: {}".format(yf_test[sample_id]))
axes.set_xticks([])  # hide ticks; pixel coordinates carry no information here
axes.set_yticks([])
# One probability per class; the argmax is the predicted label
print(network.predict(Xf_test)[sample_id])
2022-03-16 12:45:53.112813: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
[0.007654 0.0000028 0.7707164 0.0006719 0.0179828 0. 0.2029032
0. 0.0000689 0. ]
# Final evaluation on the held-out test set (loss and accuracy)
test_loss, test_acc = network.evaluate(Xf_test, yf_test)
print('Test accuracy:', test_acc)
1/313 [..............................] - ETA: 40s - loss: 0.2433 - accuracy: 0.9375
6/313 [..............................] - ETA: 3s - loss: 0.4325 - accuracy: 0.8594
14/313 [>.............................] - ETA: 2s - loss: 0.3654 - accuracy: 0.8772
2022-03-16 12:45:53.782057: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
20/313 [>.............................] - ETA: 2s - loss: 0.3526 - accuracy: 0.8766
28/313 [=>............................] - ETA: 2s - loss: 0.3395 - accuracy: 0.8806
36/313 [==>...........................] - ETA: 2s - loss: 0.3421 - accuracy: 0.8863
44/313 [===>..........................] - ETA: 2s - loss: 0.3404 - accuracy: 0.8828
52/313 [===>..........................] - ETA: 1s - loss: 0.3485 - accuracy: 0.8816
61/313 [====>.........................] - ETA: 1s - loss: 0.3540 - accuracy: 0.8842
69/313 [=====>........................] - ETA: 53:53 - loss: 0.3495 - accuracy: 0.8827
74/313 [======>.......................] - ETA: 49:10 - loss: 0.3427 - accuracy: 0.8856
81/313 [======>.......................] - ETA: 43:33 - loss: 0.3434 - accuracy: 0.8862
89/313 [=======>......................] - ETA: 38:14 - loss: 0.3404 - accuracy: 0.8848
98/313 [========>.....................] - ETA: 33:18 - loss: 0.3447 - accuracy: 0.8839
102/313 [========>.....................] - ETA: 31:23 - loss: 0.3465 - accuracy: 0.8830
110/313 [=========>....................] - ETA: 27:59 - loss: 0.3490 - accuracy: 0.8813
118/313 [==========>...................] - ETA: 25:02 - loss: 0.3547 - accuracy: 0.8827
126/313 [===========>..................] - ETA: 22:28 - loss: 0.3599 - accuracy: 0.8812
134/313 [===========>..................] - ETA: 20:13 - loss: 0.3588 - accuracy: 0.8806
140/313 [============>.................] - ETA: 18:42 - loss: 0.3590 - accuracy: 0.8804
148/313 [=============>................] - ETA: 16:52 - loss: 0.3610 - accuracy: 0.8801
155/313 [=============>................] - ETA: 15:25 - loss: 0.3621 - accuracy: 0.8786
161/313 [==============>...............] - ETA: 14:16 - loss: 0.3613 - accuracy: 0.8781
169/313 [===============>..............] - ETA: 12:53 - loss: 0.3683 - accuracy: 0.8776
178/313 [================>.............] - ETA: 11:28 - loss: 0.3682 - accuracy: 0.8771
187/313 [================>.............] - ETA: 10:11 - loss: 0.3678 - accuracy: 0.8782
196/313 [=================>............] - ETA: 9:01 - loss: 0.3663 - accuracy: 0.8777
205/313 [==================>...........] - ETA: 7:57 - loss: 0.3627 - accuracy: 0.8784
214/313 [===================>..........] - ETA: 6:59 - loss: 0.3641 - accuracy: 0.8789
223/313 [====================>.........] - ETA: 6:05 - loss: 0.3620 - accuracy: 0.8793
232/313 [=====================>........] - ETA: 5:16 - loss: 0.3608 - accuracy: 0.8792
241/313 [======================>.......] - ETA: 4:30 - loss: 0.3646 - accuracy: 0.8784
251/313 [=======================>......] - ETA: 3:43 - loss: 0.3660 - accuracy: 0.8780
260/313 [=======================>......] - ETA: 3:04 - loss: 0.3645 - accuracy: 0.8784
269/313 [========================>.....] - ETA: 2:28 - loss: 0.3646 - accuracy: 0.8780
278/313 [=========================>....] - ETA: 1:54 - loss: 0.3651 - accuracy: 0.8779
287/313 [==========================>...] - ETA: 1:22 - loss: 0.3695 - accuracy: 0.8772
295/313 [===========================>..] - ETA: 55s - loss: 0.3671 - accuracy: 0.8775
304/313 [============================>.] - ETA: 26s - loss: 0.3656 - accuracy: 0.8773
313/313 [==============================] - ETA: 0s - loss: 0.3672 - accuracy: 0.8764
313/313 [==============================] - 903s 3s/step - loss: 0.3672 - accuracy: 0.8764
Test accuracy: 0.8764000535011292
Model selection¶
How many epochs do we need for training?
Train the neural net and track the loss after every iteration on a validation set
You can add a callback to the fit function to get info on every epoch
Best model after a few epochs, then starts overfitting
from tensorflow.keras.callbacks import Callback
from IPython.display import clear_output
# For plotting the learning curve in real time
class TrainingPlot(Callback):
    """Keras callback that redraws the train/validation loss and accuracy
    curves in the notebook after every epoch, for live monitoring."""

    # Called once when training begins
    def on_train_begin(self, logs=None):
        # FIX: original used a mutable default argument (logs={}), which is
        # shared across calls — use None and treat it as "no logs".
        # Initialize the lists for holding the logs, losses and accuracies
        self.losses = []
        self.acc = []
        self.val_losses = []
        self.val_acc = []
        self.logs = []
        self.max_acc = 0  # best validation accuracy seen so far

    # Called at the end of each epoch with that epoch's metrics
    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}  # avoid mutable default
        # Append the logs, losses and accuracies to the lists
        self.logs.append(logs)
        self.losses.append(logs.get('loss'))
        self.acc.append(logs.get('accuracy'))
        self.val_losses.append(logs.get('val_loss'))
        self.val_acc.append(logs.get('val_accuracy'))
        # NOTE(review): assumes 'val_accuracy' is always present, i.e. fit()
        # was given validation_data — max(0, None) would raise otherwise.
        self.max_acc = max(self.max_acc, logs.get('val_accuracy'))
        # Before plotting ensure at least 2 epochs have passed
        if len(self.losses) > 1:
            # Clear the previous plot so the figure updates in place
            clear_output(wait=True)
            N = np.arange(0, len(self.losses))
            # Plot train loss, train acc, val loss and val acc against epochs passed
            plt.figure(figsize=(8, 3))
            plt.plot(N, self.losses, lw=2, c="b", linestyle="-", label="train_loss")
            plt.plot(N, self.acc, lw=2, c="r", linestyle="-", label="train_acc")
            plt.plot(N, self.val_losses, lw=2, c="b", linestyle=":", label="val_loss")
            plt.plot(N, self.val_acc, lw=2, c="r", linestyle=":", label="val_acc")
            plt.title("Training Loss and Accuracy [Epoch {}, Max Acc {:.4f}]".format(epoch, self.max_acc))
            plt.xlabel("Epoch #")
            plt.ylabel("Loss/Accuracy")
            plt.legend()
            plt.show()
from sklearn.model_selection import train_test_split
# Hold out the first 10k training samples as a validation set; train on the rest.
# NOTE(review): Xf_train/yf_train come from an earlier cell — presumably flattened
# 28*28 images with one-hot labels (matching the input_shape and loss below); confirm upstream.
x_val, partial_x_train = Xf_train[:10000], Xf_train[10000:]
y_val, partial_y_train = yf_train[:10000], yf_train[10000:]
# Baseline net: two 512-unit ReLU hidden layers (He init), 10-way softmax output
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# TrainingPlot callback redraws the learning curves after each epoch (verbose=0 suppresses the text log)
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=25, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses])
Early stopping¶
Stop training when the validation loss (or validation accuracy) no longer improves
Loss can be bumpy: use a moving average or wait for \(k\) steps without improvement
# Display-only snippet: stop when val_loss hasn't improved for 3 consecutive epochs.
# (The runnable version, with the required import, is in the next cell.)
earlystop = callbacks.EarlyStopping(monitor='val_loss', patience=3)
model.fit(x_train, y_train, epochs=25, batch_size=512, callbacks=[earlystop])
from tensorflow.keras import callbacks
# Stop training once val_loss has not improved for 3 consecutive epochs
earlystop = callbacks.EarlyStopping(monitor='val_loss', patience=3)
# Same baseline architecture as before: two 512-unit ReLU layers, softmax output
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()
# Early stopping typically ends training well before the 25-epoch budget
history = network.fit(partial_x_train, partial_y_train, epochs=25, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop])
Regularization and memorization capacity¶
The number of learnable parameters is called the model capacity
A model with more parameters has a higher memorization capacity
Too high capacity causes overfitting, too low causes underfitting
In the extreme, the training set can be ‘memorized’ in the weights
Smaller models are forced to learn a compressed representation that generalizes better
Find the sweet spot: e.g. start with few parameters, increase until overfitting starts.
Example: 256 nodes in first layer, 32 nodes in second layer, similar performance
# Smaller-capacity model (256 -> 32 hidden units) to illustrate that fewer
# parameters can match the larger net's performance while overfitting less
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(32, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# More patience (5 epochs) since the smaller model's val_loss is noisier
earlystop5 = callbacks.EarlyStopping(monitor='val_loss', patience=5)
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=30, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop5])
Information bottleneck¶
If a layer is too narrow, it will lose information that can never be recovered by subsequent layers
Information bottleneck theory defines a bound on the capacity of the network
Imagine that you need to learn 10 outputs (e.g. classes) and your hidden layer has 2 nodes
This is like trying to learn 10 hyperplanes from a 2-dimensional representation
Example: bottleneck of 2 nodes, no overfitting, much higher training loss
# Information-bottleneck demo: the 2-unit hidden layer cannot carry enough
# information for 10 output classes, so training loss stays high (underfitting)
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(2, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
earlystop5 = callbacks.EarlyStopping(monitor='val_loss', patience=5)
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=30, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop5])
Weight regularization (weight decay)¶
As we did many times before, we can also add weight regularization to our loss function
L1 regularization: leads to sparse networks with many weights that are 0
L2 regularization: leads to many very small weights
# Display-only snippet showing the kernel_regularizer syntax; it has no output
# layer and is replaced by the complete, runnable model in the next cell.
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001), input_shape=(28 * 28,)))
network.add(layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
from tensorflow.keras import regularizers
# L2 weight decay (lambda=0.001) on both hidden layers pushes weights toward
# many small values, reducing overfitting
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001), input_shape=(28 * 28,)))
network.add(layers.Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
earlystop5 = callbacks.EarlyStopping(monitor='val_loss', patience=5)
plot_losses = TrainingPlot()
# Longer epoch budget (50): regularized nets converge more slowly
history = network.fit(partial_x_train, partial_y_train, epochs=50, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop5])
Dropout¶
Every iteration, randomly set a number of activations \(a_i\) to 0
Dropout rate : fraction of the outputs that are zeroed-out (e.g. 0.1 - 0.5)
Idea: break up accidental non-significant learned patterns
At test time, nothing is dropped out, but the output values are scaled down by the dropout rate
Balances out that more units are active than during training
# Sketch a tiny dense net (2 inputs, 3 hidden, 1 output) with bias nodes,
# weight labels, and activations shown, to illustrate which units dropout zeroes out
fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 1], draw_bias=True, labels=True,
show_activations=True, activation=True)
Dropout layers¶
Dropout is usually implemented as a special layer
# Display-only model with dropout rate 0.5 after each hidden layer;
# it is overwritten by the rate-0.3 model below and never trained
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dropout(0.5))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.Dropout(0.5))
network.add(layers.Dense(10, activation='softmax'))
# Model actually trained: milder dropout rate of 0.3
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dropout(0.3))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.Dropout(0.3))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=50, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses])
Batch Normalization¶
We’ve seen that scaling the input is important, but what if layer activations become very large?
Same problems, starting deeper in the network
Batch normalization: normalize the activations of the previous layer within each batch
Within a batch, set the mean activation close to 0 and the standard deviation close to 1
Across batches, use an exponential moving average of the batch-wise mean and variance
Allows deeper networks that are less prone to vanishing or exploding gradients
# Display-only model: BatchNormalization + Dropout after every hidden layer.
# It has no output layer and is overwritten by the model below; never trained.
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(256, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(64, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
# Model actually trained: 3 hidden layers, each followed by batch norm + dropout.
# NOTE(review): 265 units looks like a typo for 256 — verify against the slides.
network = models.Sequential()
network.add(layers.Dense(265, activation='relu', input_shape=(28 * 28,)))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(64, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=50, batch_size=512, verbose=0,
validation_data=(x_val, y_val), callbacks=[plot_losses])
Tuning multiple hyperparameters¶
You can wrap Keras models as scikit-learn models and use any tuning technique
Keras also has built-in RandomSearch (and HyperBand and BayesianOptimization - see later)
def make_model(hp):
    """Build and compile a Keras model for one hyperparameter-tuning trial.

    Parameters:
        hp: a KerasTuner HyperParameters object used to sample the search
            space (hidden-layer width and learning rate)

    Returns:
        the compiled Keras model for this trial
    """
    # BUG FIX: the original snippet called m.add() on a never-created 'm'
    # and then returned an undefined name 'model' — both raised NameError.
    m = models.Sequential()
    m.add(Dense(units=hp.Int('units', min_value=32, max_value=512, step=32)))
    m.compile(optimizer=Adam(hp.Choice('learning rate', [1e-2, 1e-3, 1e-4])))
    return m
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
# Wrap the Keras model builder as a scikit-learn estimator so any sklearn
# tuner (here GridSearchCV) can search its hyperparameters
clf = KerasClassifier(make_model)
grid = GridSearchCV(clf, param_grid=param_grid, cv=3)
# Alternatively, use KerasTuner's built-in random search.
# NOTE(review): 'keras.RandomSearch' is inconsistent with the import above —
# presumably this should call the imported RandomSearch directly; verify.
from kerastuner.tuners import RandomSearch
tuner = keras.RandomSearch(build_model, max_trials=5)
Summary¶
Neural architectures
Training neural nets
Forward pass: Tensor operations
Backward pass: Backpropagation
Neural network design:
Activation functions
Weight initialization
Optimizers
Neural networks in practice
Model selection
Early stopping
Memorization capacity and information bottleneck
L1/L2 regularization
Dropout
Batch normalization